package sqlite3
// Backup is a handle to an ongoing online backup operation.
//
// https://sqlite.org/c3ref/backup.html
type Backup struct {
c *Conn
handle ptr_t
otherc ptr_t
}
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
//
// Backup opens the SQLite database file dstURI,
// and blocks until the entire backup is complete.
// Use [Conn.BackupInit] for incremental backup.
//
// https://sqlite.org/backup.html
func (src *Conn) Backup(srcDB, dstURI string) error {
b, err := src.BackupInit(srcDB, dstURI)
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
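// A minimal usage sketch (illustrative, not part of the package): back up
// the "main" database of an open connection to a new file in one call.
// The file names are placeholders.
//
//	db, err := sqlite3.Open("file:source.db")
//	if err != nil {
//		// handle error
//	}
//	defer db.Close()
//	if err := db.Backup("main", "file:backup.db"); err != nil {
//		// handle error
//	}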
// Restore restores dstDB on the dst connection from the "main" database in srcURI.
//
// Restore opens the SQLite database file srcURI,
// and blocks until the entire restore is complete.
//
// https://sqlite.org/backup.html
func (dst *Conn) Restore(dstDB, srcURI string) error {
src, err := dst.openDB(srcURI, OPEN_READONLY|OPEN_URI)
if err != nil {
return err
}
b, err := dst.backupInit(dst.handle, dstDB, src, "main")
if err != nil {
return err
}
defer b.Close()
_, err = b.Step(-1)
return err
}
// BackupInit initializes a backup operation to copy the content of one database into another.
//
// BackupInit opens the SQLite database file dstURI,
// then initializes a backup that copies the contents of srcDB on the src connection
// to the "main" database in dstURI.
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupinit
func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) {
dst, err := src.openDB(dstURI, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
if err != nil {
return nil, err
}
return src.backupInit(dst, "main", src.handle, srcDB)
}
func (c *Conn) backupInit(dst ptr_t, dstName string, src ptr_t, srcName string) (*Backup, error) {
defer c.arena.mark()()
dstPtr := c.arena.string(dstName)
srcPtr := c.arena.string(srcName)
other := dst
if c.handle == dst {
other = src
}
ptr := ptr_t(c.call("sqlite3_backup_init",
stk_t(dst), stk_t(dstPtr),
stk_t(src), stk_t(srcPtr)))
if ptr == 0 {
defer c.closeDB(other)
rc := res_t(c.call("sqlite3_errcode", stk_t(dst)))
return nil, c.sqlite.error(rc, dst)
}
return &Backup{
c: c,
otherc: other,
handle: ptr,
}, nil
}
// Close finishes a backup operation.
//
// It is safe to close a nil, zero or closed Backup.
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupfinish
func (b *Backup) Close() error {
if b == nil || b.handle == 0 {
return nil
}
rc := res_t(b.c.call("sqlite3_backup_finish", stk_t(b.handle)))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(rc)
}
// Step copies up to nPage pages between the source and destination databases.
// If nPage is negative, all remaining source pages are copied.
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
rc := res_t(b.c.call("sqlite3_backup_step", stk_t(b.handle), stk_t(nPage)))
if rc == _DONE {
return true, nil
}
return false, b.c.error(rc)
}
// Remaining returns the number of pages still to be backed up
// at the conclusion of the most recent [Backup.Step].
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
n := int32(b.c.call("sqlite3_backup_remaining", stk_t(b.handle)))
return int(n)
}
// PageCount returns the total number of pages in the source database
// at the conclusion of the most recent [Backup.Step].
//
// https://sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
n := int32(b.c.call("sqlite3_backup_pagecount", stk_t(b.handle)))
return int(n)
}
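// An incremental backup sketch (illustrative): copy a bounded number of
// pages per call to Step, reporting progress between steps. The page batch
// size and file names are assumptions for the example.
//
//	b, err := db.BackupInit("main", "file:backup.db")
//	if err != nil {
//		// handle error
//	}
//	defer b.Close()
//	for {
//		done, err := b.Step(64) // copy up to 64 pages, then yield
//		if err != nil {
//			// handle error
//			break
//		}
//		log.Printf("backed up %d of %d pages",
//			b.PageCount()-b.Remaining(), b.PageCount())
//		if done {
//			break
//		}
//	}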
package sqlite3
import (
"io"
"github.com/ncruces/go-sqlite3/internal/util"
)
// ZeroBlob represents a zero-filled, length n BLOB
// that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
type ZeroBlob int64
// Blob is a handle to an open BLOB.
//
// It implements [io.ReadWriteSeeker] for incremental BLOB I/O.
//
// https://sqlite.org/c3ref/blob.html
type Blob struct {
c *Conn
bytes int64
offset int64
handle ptr_t
bufptr ptr_t
buflen int64
}
var _ io.ReadWriteSeeker = &Blob{}
// OpenBlob opens a BLOB for incremental I/O.
//
// https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
if c.interrupt.Err() != nil {
return nil, INTERRUPT
}
defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
tablePtr := c.arena.string(table)
columnPtr := c.arena.string(column)
var flags int32
if write {
flags = 1
}
rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle),
stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(row), stk_t(flags), stk_t(blobPtr)))
if err := c.error(rc); err != nil {
return nil, err
}
blob := Blob{c: c}
blob.handle = util.Read32[ptr_t](c.mod, blobPtr)
blob.bytes = int64(int32(c.call("sqlite3_blob_bytes", stk_t(blob.handle))))
return &blob, nil
}
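// A usage sketch (illustrative): stream a BLOB column through the
// io.ReadWriteSeeker interface. The schema, table, column and rowid are
// assumptions for the example; io.Copy uses Blob.WriteTo under the hood.
//
//	blob, err := db.OpenBlob("main", "files", "data", 1, false)
//	if err != nil {
//		// handle error
//	}
//	defer blob.Close()
//	_, err = io.Copy(os.Stdout, blob)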
// Close closes a BLOB handle.
//
// It is safe to close a nil, zero or closed Blob.
//
// https://sqlite.org/c3ref/blob_close.html
func (b *Blob) Close() error {
if b == nil || b.handle == 0 {
return nil
}
rc := res_t(b.c.call("sqlite3_blob_close", stk_t(b.handle)))
b.c.free(b.bufptr)
b.handle = 0
return b.c.error(rc)
}
// Size returns the size of the BLOB in bytes.
//
// https://sqlite.org/c3ref/blob_bytes.html
func (b *Blob) Size() int64 {
return b.bytes
}
// Read implements the [io.Reader] interface.
//
// https://sqlite.org/c3ref/blob_read.html
func (b *Blob) Read(p []byte) (n int, err error) {
if b.offset >= b.bytes {
return 0, io.EOF
}
want := int64(len(p))
avail := b.bytes - b.offset
if want > avail {
want = avail
}
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, want)
b.buflen = want
}
rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle),
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
err = b.c.error(rc)
if err != nil {
return 0, err
}
b.offset += want
if b.offset >= b.bytes {
err = io.EOF
}
copy(p, util.View(b.c.mod, b.bufptr, want))
return int(want), err
}
// WriteTo implements the [io.WriterTo] interface.
//
// https://sqlite.org/c3ref/blob_read.html
func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
if b.offset >= b.bytes {
return 0, nil
}
want := int64(1024 * 1024)
avail := b.bytes - b.offset
if want > avail {
want = avail
}
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, want)
b.buflen = want
}
for want > 0 {
rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle),
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
err = b.c.error(rc)
if err != nil {
return n, err
}
mem := util.View(b.c.mod, b.bufptr, want)
m, err := w.Write(mem[:want])
b.offset += int64(m)
n += int64(m)
if err != nil {
return n, err
}
if int64(m) != want {
// notest // Write misbehaving
return n, io.ErrShortWrite
}
avail = b.bytes - b.offset
if want > avail {
want = avail
}
}
return n, nil
}
// Write implements the [io.Writer] interface.
//
// https://sqlite.org/c3ref/blob_write.html
func (b *Blob) Write(p []byte) (n int, err error) {
want := int64(len(p))
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, want)
b.buflen = want
}
util.WriteBytes(b.c.mod, b.bufptr, p)
rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle),
stk_t(b.bufptr), stk_t(want), stk_t(b.offset)))
err = b.c.error(rc)
if err != nil {
return 0, err
}
b.offset += int64(len(p))
return len(p), nil
}
// ReadFrom implements the [io.ReaderFrom] interface.
//
// https://sqlite.org/c3ref/blob_write.html
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
want := int64(1024 * 1024)
avail := b.bytes - b.offset
if l, ok := r.(*io.LimitedReader); ok && want > l.N {
want = l.N
}
if want > avail {
want = avail
}
if want < 1 {
want = 1
}
if want > b.buflen {
b.bufptr = b.c.realloc(b.bufptr, want)
b.buflen = want
}
for {
mem := util.View(b.c.mod, b.bufptr, want)
m, err := r.Read(mem[:want])
if m > 0 {
rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle),
stk_t(b.bufptr), stk_t(m), stk_t(b.offset)))
err := b.c.error(rc)
if err != nil {
return n, err
}
b.offset += int64(m)
n += int64(m)
}
if err == io.EOF {
return n, nil
}
if err != nil {
return n, err
}
avail = b.bytes - b.offset
if want > avail {
want = avail
}
if want < 1 {
want = 1
}
}
}
// Seek implements the [io.Seeker] interface.
func (b *Blob) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, util.WhenceErr
case io.SeekStart:
break
case io.SeekCurrent:
offset += b.offset
case io.SeekEnd:
offset += b.bytes
}
if offset < 0 {
return 0, util.OffsetErr
}
b.offset = offset
return offset, nil
}
// Reopen moves a BLOB handle to a new row of the same database table.
//
// https://sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
if b.c.interrupt.Err() != nil {
return INTERRUPT
}
err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row))))
b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle))))
b.offset = 0
return err
}
package sqlite3
import (
"context"
"fmt"
"strconv"
"sync/atomic"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
)
// Config makes configuration changes to a database connection.
// Only boolean configuration options are supported.
// Called with no argument, Config reads the current configuration value;
// called with one argument, it sets and returns the new value.
//
// https://sqlite.org/c3ref/db_config.html
func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) {
if op < DBCONFIG_ENABLE_FKEY || op > DBCONFIG_REVERSE_SCANORDER {
return false, MISUSE
}
// We need to call sqlite3_db_config, a variadic function.
// We only support the `int int*` variants.
// The int is a three-valued bool: -1 queries, 0/1 sets false/true.
// The int* points to where new state will be written to.
// The vararg is a pointer to an array containing these arguments:
// an int and an int* pointing to that int.
defer c.arena.mark()()
argsPtr := c.arena.new(intlen + ptrlen)
var flag int32
switch {
case len(arg) == 0:
flag = -1
case arg[0]:
flag = 1
}
util.Write32(c.mod, argsPtr+0*ptrlen, flag)
util.Write32(c.mod, argsPtr+1*ptrlen, argsPtr)
rc := res_t(c.call("sqlite3_db_config", stk_t(c.handle),
stk_t(op), stk_t(argsPtr)))
return util.ReadBool(c.mod, argsPtr), c.error(rc)
}
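// A usage sketch (illustrative): enable foreign key enforcement, then read
// the option back without changing it.
//
//	if _, err := db.Config(sqlite3.DBCONFIG_ENABLE_FKEY, true); err != nil {
//		// handle error
//	}
//	enabled, err := db.Config(sqlite3.DBCONFIG_ENABLE_FKEY) // query only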
var defaultLogger atomic.Pointer[func(code ExtendedErrorCode, msg string)]
// ConfigLog sets up the default error logging callback for new connections.
//
// https://sqlite.org/errlog.html
func ConfigLog(cb func(code ExtendedErrorCode, msg string)) {
defaultLogger.Store(&cb)
}
// ConfigLog sets up the error logging callback for the connection.
//
// https://sqlite.org/errlog.html
func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error {
var enable int32
if cb != nil {
enable = 1
}
rc := res_t(c.call("sqlite3_config_log_go", stk_t(enable)))
if err := c.error(rc); err != nil {
return err
}
c.log = cb
return nil
}
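// A usage sketch (illustrative): route SQLite's error log for this
// connection to the standard library logger.
//
//	db.ConfigLog(func(code sqlite3.ExtendedErrorCode, msg string) {
//		log.Printf("sqlite3 (%d): %s", code, msg)
//	})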
func logCallback(ctx context.Context, mod api.Module, _ ptr_t, iCode res_t, zMsg ptr_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.log != nil {
msg := util.ReadString(mod, zMsg, _MAX_LENGTH)
c.log(xErrorCode(iCode), msg)
}
}
// Log writes a message into the error log established by [Conn.ConfigLog].
//
// https://sqlite.org/c3ref/log.html
func (c *Conn) Log(code ExtendedErrorCode, format string, a ...any) {
if c.log != nil {
c.log(code, fmt.Sprintf(format, a...))
}
}
// FileControl allows low-level control of database files.
// Only a subset of opcodes is supported.
//
// https://sqlite.org/c3ref/file_control.html
func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, error) {
defer c.arena.mark()()
ptr := c.arena.new(max(ptrlen, intlen))
var schemaPtr ptr_t
if schema != "" {
schemaPtr = c.arena.string(schema)
}
var rc res_t
var ret any
switch op {
default:
return nil, MISUSE
case FCNTL_RESET_CACHE:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), 0))
case FCNTL_PERSIST_WAL, FCNTL_POWERSAFE_OVERWRITE:
var flag int32
switch {
case len(arg) == 0:
flag = -1
case arg[0]:
flag = 1
}
util.Write32(c.mod, ptr, flag)
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = util.ReadBool(c.mod, ptr)
case FCNTL_CHUNK_SIZE:
util.Write32(c.mod, ptr, int32(arg[0].(int)))
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
case FCNTL_RESERVE_BYTES:
bytes := -1
if len(arg) > 0 {
bytes = arg[0].(int)
}
util.Write32(c.mod, ptr, int32(bytes))
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = int(util.Read32[int32](c.mod, ptr))
case FCNTL_DATA_VERSION:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = util.Read32[uint32](c.mod, ptr)
case FCNTL_LOCKSTATE:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
ret = util.Read32[vfs.LockLevel](c.mod, ptr)
case FCNTL_VFS_POINTER:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
if rc == _OK {
const zNameOffset = 16
ptr = util.Read32[ptr_t](c.mod, ptr)
ptr = util.Read32[ptr_t](c.mod, ptr+zNameOffset)
name := util.ReadString(c.mod, ptr, _MAX_NAME)
ret = vfs.Find(name)
}
case FCNTL_FILE_POINTER, FCNTL_JOURNAL_POINTER:
rc = res_t(c.call("sqlite3_file_control",
stk_t(c.handle), stk_t(schemaPtr),
stk_t(op), stk_t(ptr)))
if rc == _OK {
const fileHandleOffset = 4
ptr = util.Read32[ptr_t](c.mod, ptr)
ptr = util.Read32[ptr_t](c.mod, ptr+fileHandleOffset)
ret = util.GetHandle(c.ctx, ptr)
}
}
if err := c.error(rc); err != nil {
return nil, err
}
return ret, nil
}
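// A usage sketch (illustrative): query, then set, the persistent WAL flag
// on the "main" database. For the boolean opcodes the returned value is a
// bool; for FCNTL_RESERVE_BYTES it is an int (see the switch above for the
// other opcodes).
//
//	was, err := db.FileControl("main", sqlite3.FCNTL_PERSIST_WAL)       // query
//	now, err := db.FileControl("main", sqlite3.FCNTL_PERSIST_WAL, true) // enable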
// Limit allows the size of various constructs to be
// limited on a connection-by-connection basis.
//
// https://sqlite.org/c3ref/limit.html
func (c *Conn) Limit(id LimitCategory, value int) int {
v := int32(c.call("sqlite3_limit", stk_t(c.handle), stk_t(id), stk_t(value)))
return int(v)
}
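// A usage sketch (illustrative): cap the number of attached databases, and
// read a limit back by passing a negative value (which leaves it unchanged).
//
//	prev := db.Limit(sqlite3.LIMIT_ATTACHED, 2) // set; returns the prior limit
//	cur := db.Limit(sqlite3.LIMIT_ATTACHED, -1) // negative values only query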
// SetAuthorizer registers an authorizer callback with the database connection.
//
// https://sqlite.org/c3ref/set_authorizer.html
func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, inner string) AuthorizerReturnCode) error {
var enable int32
if cb != nil {
enable = 1
}
rc := res_t(c.call("sqlite3_set_authorizer_go", stk_t(c.handle), stk_t(enable)))
if err := c.error(rc); err != nil {
return err
}
c.authorizer = cb
return nil
}
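// A usage sketch (illustrative): deny DROP TABLE statements while allowing
// everything else.
//
//	db.SetAuthorizer(func(action sqlite3.AuthorizerActionCode,
//		name3rd, name4th, schema, inner string) sqlite3.AuthorizerReturnCode {
//		if action == sqlite3.AUTH_DROP_TABLE {
//			return sqlite3.AUTH_DENY
//		}
//		return sqlite3.AUTH_OK
//	})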
func authorizerCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner ptr_t) (rc AuthorizerReturnCode) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.authorizer != nil {
var name3rd, name4th, schema, inner string
if zName3rd != 0 {
name3rd = util.ReadString(mod, zName3rd, _MAX_NAME)
}
if zName4th != 0 {
name4th = util.ReadString(mod, zName4th, _MAX_NAME)
}
if zSchema != 0 {
schema = util.ReadString(mod, zSchema, _MAX_NAME)
}
if zInner != 0 {
inner = util.ReadString(mod, zInner, _MAX_NAME)
}
rc = c.authorizer(action, name3rd, name4th, schema, inner)
}
return rc
}
// Trace registers a trace callback function against the database connection.
//
// https://sqlite.org/c3ref/trace_v2.html
func (c *Conn) Trace(mask TraceEvent, cb func(evt TraceEvent, arg1 any, arg2 any) error) error {
rc := res_t(c.call("sqlite3_trace_go", stk_t(c.handle), stk_t(mask)))
if err := c.error(rc); err != nil {
return err
}
c.trace = cb
return nil
}
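// A usage sketch (illustrative): log every statement as it starts.
// For TRACE_STMT, traceCallback below passes the *Stmt as arg1 and its SQL
// as arg2.
//
//	db.Trace(sqlite3.TRACE_STMT, func(evt sqlite3.TraceEvent, arg1, arg2 any) error {
//		log.Println("executing:", arg2)
//		return nil
//	})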
func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 ptr_t) (rc res_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.trace != nil {
var arg1, arg2 any
if evt == TRACE_CLOSE {
arg1 = c
} else {
for _, s := range c.stmts {
if pArg1 == s.handle {
arg1 = s
switch evt {
case TRACE_STMT:
arg2 = s.SQL()
case TRACE_PROFILE:
arg2 = util.Read64[int64](mod, pArg2)
}
break
}
}
}
if arg1 != nil {
_, rc = errorCode(c.trace(evt, arg1, arg2), ERROR)
}
}
return rc
}
// WALCheckpoint checkpoints a WAL database.
//
// https://sqlite.org/c3ref/wal_checkpoint_v2.html
func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) {
if c.interrupt.Err() != nil {
return 0, 0, INTERRUPT
}
defer c.arena.mark()()
nLogPtr := c.arena.new(ptrlen)
nCkptPtr := c.arena.new(ptrlen)
schemaPtr := c.arena.string(schema)
rc := res_t(c.call("sqlite3_wal_checkpoint_v2",
stk_t(c.handle), stk_t(schemaPtr), stk_t(mode),
stk_t(nLogPtr), stk_t(nCkptPtr)))
nLog = int(util.Read32[int32](c.mod, nLogPtr))
nCkpt = int(util.Read32[int32](c.mod, nCkptPtr))
return nLog, nCkpt, c.error(rc)
}
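// A usage sketch (illustrative): force a full checkpoint of the "main"
// database and report how much of the WAL was transferred.
//
//	nLog, nCkpt, err := db.WALCheckpoint("main", sqlite3.CHECKPOINT_FULL)
//	if err == nil {
//		log.Printf("checkpointed %d of %d WAL frames", nCkpt, nLog)
//	}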
// WALAutoCheckpoint configures WAL auto-checkpoints.
//
// https://sqlite.org/c3ref/wal_autocheckpoint.html
func (c *Conn) WALAutoCheckpoint(pages int) error {
rc := res_t(c.call("sqlite3_wal_autocheckpoint", stk_t(c.handle), stk_t(pages)))
return c.error(rc)
}
// WALHook registers a callback function to be invoked
// each time data is committed to a database in WAL mode.
//
// https://sqlite.org/c3ref/wal_hook.html
func (c *Conn) WALHook(cb func(db *Conn, schema string, pages int) error) {
var enable int32
if cb != nil {
enable = 1
}
c.call("sqlite3_wal_hook_go", stk_t(c.handle), stk_t(enable))
c.wal = cb
}
func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema ptr_t, pages int32) (rc res_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.wal != nil {
schema := util.ReadString(mod, zSchema, _MAX_NAME)
err := c.wal(c, schema, int(pages))
_, rc = errorCode(err, ERROR)
}
return rc
}
// AutoVacuumPages registers an autovacuum compaction amount callback.
//
// https://sqlite.org/c3ref/autovacuum_pages.html
func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error {
var funcPtr ptr_t
if cb != nil {
funcPtr = util.AddHandle(c.ctx, cb)
}
rc := res_t(c.call("sqlite3_autovacuum_pages_go", stk_t(c.handle), stk_t(funcPtr)))
return c.error(rc)
}
func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema ptr_t, nDbPage, nFreePage, nBytePerPage uint32) uint32 {
fn := util.GetHandle(ctx, pApp).(func(schema string, dbPages, freePages, bytesPerPage uint) uint)
schema := util.ReadString(mod, zSchema, _MAX_NAME)
return uint32(fn(schema, uint(nDbPage), uint(nFreePage), uint(nBytePerPage)))
}
// SoftHeapLimit imposes a soft limit on heap size.
//
// https://sqlite.org/c3ref/hard_heap_limit64.html
func (c *Conn) SoftHeapLimit(n int64) int64 {
return int64(c.call("sqlite3_soft_heap_limit64", stk_t(n)))
}
// HardHeapLimit imposes a hard limit on heap size.
//
// https://sqlite.org/c3ref/hard_heap_limit64.html
func (c *Conn) HardHeapLimit(n int64) int64 {
return int64(c.call("sqlite3_hard_heap_limit64", stk_t(n)))
}
// EnableChecksums enables checksums on a database.
//
// https://sqlite.org/cksumvfs.html
func (c *Conn) EnableChecksums(schema string) error {
r, err := c.FileControl(schema, FCNTL_RESERVE_BYTES)
if err != nil {
return err
}
if r == 8 {
// Correct value, enabled.
return nil
}
if r == 0 {
// Default value, enable.
_, err = c.FileControl(schema, FCNTL_RESERVE_BYTES, 8)
if err != nil {
return err
}
r, err = c.FileControl(schema, FCNTL_RESERVE_BYTES)
if err != nil {
return err
}
}
if r != 8 {
// Invalid value.
return util.ErrorString("sqlite3: reserve bytes must be 8, is: " + strconv.Itoa(r.(int)))
}
// VACUUM the database.
if schema != "" {
err = c.Exec(`VACUUM ` + QuoteIdentifier(schema))
} else {
err = c.Exec(`VACUUM`)
}
if err != nil {
return err
}
// Checkpoint the WAL.
_, _, err = c.WALCheckpoint(schema, CHECKPOINT_FULL)
return err
}
package sqlite3
import (
"context"
"fmt"
"iter"
"math"
"math/rand"
"net/url"
"runtime"
"strings"
"time"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
)
// Conn is a database connection handle.
// A Conn is not safe for concurrent use by multiple goroutines.
//
// https://sqlite.org/c3ref/sqlite3.html
type Conn struct {
*sqlite
interrupt context.Context
stmts []*Stmt
busy func(context.Context, int) bool
log func(xErrorCode, string)
collation func(*Conn, string)
wal func(*Conn, string, int) error
trace func(TraceEvent, any, any) error
authorizer func(AuthorizerActionCode, string, string, string, string) AuthorizerReturnCode
update func(AuthorizerActionCode, string, string, int64)
commit func() bool
rollback func()
busy1st time.Time
busylst time.Time
arena arena
handle ptr_t
gosched uint8
}
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
func Open(filename string) (*Conn, error) {
return newConn(context.Background(), filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
// OpenContext is like [Open] but includes a context,
// which is used to interrupt the process of opening the connection.
func OpenContext(ctx context.Context, filename string) (*Conn, error) {
return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
// OpenFlags opens an SQLite database file as specified by the filename argument.
//
// If none of the required flags are used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
//
// https://sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
flags |= OPEN_READWRITE | OPEN_CREATE
}
return newConn(context.Background(), filename, flags)
}
type connKey = util.ConnKey
func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _ error) {
err := ctx.Err()
if err != nil {
return nil, err
}
c := &Conn{interrupt: ctx}
c.sqlite, err = instantiateSQLite()
if err != nil {
return nil, err
}
defer func() {
if ret == nil {
c.Close()
c.sqlite.close()
} else {
c.interrupt = context.Background()
}
}()
c.ctx = context.WithValue(c.ctx, connKey{}, c)
if logger := defaultLogger.Load(); logger != nil {
c.ConfigLog(*logger)
}
c.arena = c.newArena()
c.handle, err = c.openDB(filename, flags)
if err == nil {
err = initExtensions(c)
}
if err != nil {
return nil, err
}
return c, nil
}
func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
defer c.arena.mark()()
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
flags |= OPEN_EXRESCODE
rc := res_t(c.call("sqlite3_open_v2", stk_t(namePtr), stk_t(connPtr), stk_t(flags), 0))
handle := util.Read32[ptr_t](c.mod, connPtr)
if err := c.sqlite.error(rc, handle); err != nil {
c.closeDB(handle)
return 0, err
}
c.call("sqlite3_progress_handler_go", stk_t(handle), 1000)
if flags&OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok {
query, _ := url.ParseQuery(after)
for _, p := range query["_pragma"] {
pragmas.WriteString(`PRAGMA `)
pragmas.WriteString(p)
pragmas.WriteString(`;`)
}
}
if pragmas.Len() != 0 {
pragmaPtr := c.arena.string(pragmas.String())
rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0))
if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
c.closeDB(handle)
return 0, err
}
}
}
return handle, nil
}
func (c *Conn) closeDB(handle ptr_t) {
rc := res_t(c.call("sqlite3_close_v2", stk_t(handle)))
if err := c.sqlite.error(rc, handle); err != nil {
panic(err)
}
}
// Close closes the database connection.
//
// If the database connection is associated with unfinalized prepared statements,
// open blob handles, and/or unfinished backup objects,
// Close will leave the database connection open and return [BUSY].
//
// It is safe to close a nil, zero or closed Conn.
//
// https://sqlite.org/c3ref/close.html
func (c *Conn) Close() error {
if c == nil || c.handle == 0 {
return nil
}
rc := res_t(c.call("sqlite3_close", stk_t(c.handle)))
if err := c.error(rc); err != nil {
return err
}
c.handle = 0
return c.close()
}
// Exec is a convenience function that allows an application to run
// multiple statements of SQL without having to use a lot of code.
//
// https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
if c.interrupt.Err() != nil {
return INTERRUPT
}
return c.exec(sql)
}
func (c *Conn) exec(sql string) error {
defer c.arena.mark()()
textPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0))
return c.error(rc, sql)
}
// Prepare calls [Conn.PrepareFlags] with no flags.
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
return c.PrepareFlags(sql, 0)
}
// PrepareFlags compiles the first SQL statement in sql;
// tail is left pointing to what remains uncompiled.
// If the input text contains no SQL (if the input is an empty string or a comment),
// both stmt and err will be nil.
//
// https://sqlite.org/c3ref/prepare.html
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
if len(sql) > _MAX_SQL_LENGTH {
return nil, "", TOOBIG
}
if c.interrupt.Err() != nil {
return nil, "", INTERRUPT
}
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
textPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle),
stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags),
stk_t(stmtPtr), stk_t(tailPtr)))
stmt = &Stmt{c: c, sql: sql}
stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr)
if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-textPtr:]; sql != "" {
tail = sql
}
if err := c.error(rc, sql); err != nil {
return nil, "", err
}
if stmt.handle == 0 {
return nil, "", nil
}
c.stmts = append(c.stmts, stmt)
return stmt, tail, nil
}
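// A usage sketch (illustrative): prepare a query and step through its rows.
// The Stmt methods used here (Step, ColumnText, Err, Close) are defined
// elsewhere in this package.
//
//	stmt, _, err := db.Prepare(`SELECT name FROM sqlite_schema`)
//	if err != nil {
//		// handle error
//	}
//	defer stmt.Close()
//	for stmt.Step() {
//		log.Println(stmt.ColumnText(0))
//	}
//	if err := stmt.Err(); err != nil {
//		// handle error
//	}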
// DBName returns the schema name for the n-th database on the database connection.
//
// https://sqlite.org/c3ref/db_name.html
func (c *Conn) DBName(n int) string {
ptr := ptr_t(c.call("sqlite3_db_name", stk_t(c.handle), stk_t(n)))
if ptr == 0 {
return ""
}
return util.ReadString(c.mod, ptr, _MAX_NAME)
}
// Filename returns the filename for a database.
//
// https://sqlite.org/c3ref/db_filename.html
func (c *Conn) Filename(schema string) *vfs.Filename {
var ptr ptr_t
if schema != "" {
defer c.arena.mark()()
ptr = c.arena.string(schema)
}
ptr = ptr_t(c.call("sqlite3_db_filename", stk_t(c.handle), stk_t(ptr)))
return vfs.GetFilename(c.ctx, c.mod, ptr, vfs.OPEN_MAIN_DB)
}
// ReadOnly determines if a database is read-only.
//
// https://sqlite.org/c3ref/db_readonly.html
func (c *Conn) ReadOnly(schema string) (ro bool, ok bool) {
var ptr ptr_t
if schema != "" {
defer c.arena.mark()()
ptr = c.arena.string(schema)
}
b := int32(c.call("sqlite3_db_readonly", stk_t(c.handle), stk_t(ptr)))
return b > 0, b < 0
}
// GetAutocommit tests the connection for auto-commit mode.
//
// https://sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
b := int32(c.call("sqlite3_get_autocommit", stk_t(c.handle)))
return b != 0
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
// on the database connection.
//
// https://sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() int64 {
return int64(c.call("sqlite3_last_insert_rowid", stk_t(c.handle)))
}
// SetLastInsertRowID allows the application to set the value returned by
// [Conn.LastInsertRowID].
//
// https://sqlite.org/c3ref/set_last_insert_rowid.html
func (c *Conn) SetLastInsertRowID(id int64) {
c.call("sqlite3_set_last_insert_rowid", stk_t(c.handle), stk_t(id))
}
// Changes returns the number of rows modified, inserted or deleted
// by the most recently completed INSERT, UPDATE or DELETE statement
// on the database connection.
//
// https://sqlite.org/c3ref/changes.html
func (c *Conn) Changes() int64 {
return int64(c.call("sqlite3_changes64", stk_t(c.handle)))
}
// TotalChanges returns the number of rows modified, inserted or deleted
// by all INSERT, UPDATE or DELETE statements completed
// since the database connection was opened.
//
// https://sqlite.org/c3ref/total_changes.html
func (c *Conn) TotalChanges() int64 {
return int64(c.call("sqlite3_total_changes64", stk_t(c.handle)))
}
// ReleaseMemory frees memory used by a database connection.
//
// https://sqlite.org/c3ref/db_release_memory.html
func (c *Conn) ReleaseMemory() error {
rc := res_t(c.call("sqlite3_db_release_memory", stk_t(c.handle)))
return c.error(rc)
}
// GetInterrupt gets the context set with [Conn.SetInterrupt].
func (c *Conn) GetInterrupt() context.Context {
return c.interrupt
}
// SetInterrupt interrupts a long-running query when a context is done.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until the context is reset by another call to SetInterrupt.
//
// To associate a timeout with a connection:
//
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
// conn.SetInterrupt(ctx)
// defer cancel()
//
// SetInterrupt returns the old context assigned to the connection.
//
// https://sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
if ctx == nil {
panic("nil Context")
}
old = c.interrupt
c.interrupt = ctx
return old
}
func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.gosched++; c.gosched%16 == 0 {
runtime.Gosched()
}
if c.interrupt.Err() != nil {
interrupt = 1
}
}
return interrupt
}
// BusyTimeout sets a busy timeout.
//
// https://sqlite.org/c3ref/busy_timeout.html
func (c *Conn) BusyTimeout(timeout time.Duration) error {
ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32)
rc := res_t(c.call("sqlite3_busy_timeout", stk_t(c.handle), stk_t(ms)))
return c.error(rc)
}
func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry int32) {
// https://fractaledmind.github.io/2024/04/15/sqlite-on-rails-the-how-and-why-of-optimal-performance/
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil {
switch {
case count == 0:
c.busy1st = time.Now()
case time.Since(c.busy1st) >= time.Duration(tmout)*time.Millisecond:
return 0
}
if time.Since(c.busylst) < time.Millisecond {
const sleepIncrement = 2*1024*1024 - 1 // power of two, ~2ms
time.Sleep(time.Duration(rand.Int63() & sleepIncrement))
}
c.busylst = time.Now()
return 1
}
return 0
}
// BusyHandler registers a callback to handle [BUSY] errors.
//
// https://sqlite.org/c3ref/busy_handler.html
func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error {
var enable int32
if cb != nil {
enable = 1
}
rc := res_t(c.call("sqlite3_busy_handler_go", stk_t(c.handle), stk_t(enable)))
if err := c.error(rc); err != nil {
return err
}
c.busy = cb
return nil
}
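// A usage sketch (illustrative): retry a busy database a bounded number of
// times, giving up early if the interrupt context is done.
//
//	db.BusyHandler(func(ctx context.Context, count int) (retry bool) {
//		return count < 10 && ctx.Err() == nil
//	})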
func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil {
if interrupt := c.interrupt; interrupt.Err() == nil &&
c.busy(interrupt, int(count)) {
retry = 1
}
}
return retry
}
// Status retrieves runtime status information about a database connection.
//
// https://sqlite.org/c3ref/db_status.html
func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err error) {
defer c.arena.mark()()
hiPtr := c.arena.new(intlen)
curPtr := c.arena.new(intlen)
var i int32
if reset {
i = 1
}
rc := res_t(c.call("sqlite3_db_status", stk_t(c.handle),
stk_t(op), stk_t(curPtr), stk_t(hiPtr), stk_t(i)))
if err = c.error(rc); err == nil {
current = int(util.Read32[int32](c.mod, curPtr))
highwater = int(util.Read32[int32](c.mod, hiPtr))
}
return
}
// TableColumnMetadata extracts metadata about a column of a table.
//
// https://sqlite.org/c3ref/table_column_metadata.html
func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) {
defer c.arena.mark()()
var schemaPtr, columnPtr ptr_t
declTypePtr := c.arena.new(ptrlen)
collSeqPtr := c.arena.new(ptrlen)
notNullPtr := c.arena.new(ptrlen)
autoIncPtr := c.arena.new(ptrlen)
primaryKeyPtr := c.arena.new(ptrlen)
if schema != "" {
schemaPtr = c.arena.string(schema)
}
tablePtr := c.arena.string(table)
if column != "" {
columnPtr = c.arena.string(column)
}
rc := res_t(c.call("sqlite3_table_column_metadata", stk_t(c.handle),
stk_t(schemaPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(declTypePtr), stk_t(collSeqPtr),
stk_t(notNullPtr), stk_t(primaryKeyPtr), stk_t(autoIncPtr)))
if err = c.error(rc); err == nil && column != "" {
if ptr := util.Read32[ptr_t](c.mod, declTypePtr); ptr != 0 {
declType = util.ReadString(c.mod, ptr, _MAX_NAME)
}
if ptr := util.Read32[ptr_t](c.mod, collSeqPtr); ptr != 0 {
collSeq = util.ReadString(c.mod, ptr, _MAX_NAME)
}
notNull = util.ReadBool(c.mod, notNullPtr)
autoInc = util.ReadBool(c.mod, autoIncPtr)
primaryKey = util.ReadBool(c.mod, primaryKeyPtr)
}
return
}
func (c *Conn) error(rc res_t, sql ...string) error {
return c.sqlite.error(rc, c.handle, sql...)
}
// Stmts returns an iterator for the prepared statements
// associated with the database connection.
//
// https://sqlite.org/c3ref/next_stmt.html
func (c *Conn) Stmts() iter.Seq[*Stmt] {
return func(yield func(*Stmt) bool) {
for _, s := range c.stmts {
if !yield(s) {
break
}
}
}
}
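// A usage sketch (illustrative): finalize every prepared statement still
// associated with the connection before closing it.
//
//	for stmt := range db.Stmts() {
//		stmt.Close()
//	}
//	db.Close()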
package sqlite3
import (
"strconv"
"github.com/ncruces/go-sqlite3/internal/util"
)
const (
_OK = 0 /* Successful result */
_ROW = 100 /* sqlite3_step() has another row ready */
_DONE = 101 /* sqlite3_step() has finished executing */
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
_MAX_LENGTH = 1e9
_MAX_SQL_LENGTH = 1e9
ptrlen = util.PtrLen
intlen = util.IntLen
)
type (
stk_t = util.Stk_t
ptr_t = util.Ptr_t
res_t = util.Res_t
)
// ErrorCode is a result code that [Error.Code] might return.
//
// https://sqlite.org/rescode.html
type ErrorCode uint8
const (
ERROR ErrorCode = 1 /* Generic error */
INTERNAL ErrorCode = 2 /* Internal logic error in SQLite */
PERM ErrorCode = 3 /* Access permission denied */
ABORT ErrorCode = 4 /* Callback routine requested an abort */
BUSY ErrorCode = 5 /* The database file is locked */
LOCKED ErrorCode = 6 /* A table in the database is locked */
NOMEM ErrorCode = 7 /* A malloc() failed */
READONLY ErrorCode = 8 /* Attempt to write a readonly database */
INTERRUPT ErrorCode = 9 /* Operation terminated by sqlite3_interrupt() */
IOERR ErrorCode = 10 /* Some kind of disk I/O error occurred */
CORRUPT ErrorCode = 11 /* The database disk image is malformed */
NOTFOUND ErrorCode = 12 /* Unknown opcode in sqlite3_file_control() */
FULL ErrorCode = 13 /* Insertion failed because database is full */
CANTOPEN ErrorCode = 14 /* Unable to open the database file */
PROTOCOL ErrorCode = 15 /* Database lock protocol error */
EMPTY ErrorCode = 16 /* Internal use only */
SCHEMA ErrorCode = 17 /* The database schema changed */
TOOBIG ErrorCode = 18 /* String or BLOB exceeds size limit */
CONSTRAINT ErrorCode = 19 /* Abort due to constraint violation */
MISMATCH ErrorCode = 20 /* Data type mismatch */
MISUSE ErrorCode = 21 /* Library used incorrectly */
NOLFS ErrorCode = 22 /* Uses OS features not supported on host */
AUTH ErrorCode = 23 /* Authorization denied */
FORMAT ErrorCode = 24 /* Not used */
RANGE ErrorCode = 25 /* 2nd parameter to sqlite3_bind out of range */
NOTADB ErrorCode = 26 /* File opened that is not a database file */
NOTICE ErrorCode = 27 /* Notifications from sqlite3_log() */
WARNING ErrorCode = 28 /* Warnings from sqlite3_log() */
)
// ExtendedErrorCode is a result code that [Error.ExtendedCode] might return.
//
// https://sqlite.org/rescode.html
type (
ExtendedErrorCode uint16
xErrorCode = ExtendedErrorCode
)
const (
ERROR_MISSING_COLLSEQ ExtendedErrorCode = xErrorCode(ERROR) | (1 << 8)
ERROR_RETRY ExtendedErrorCode = xErrorCode(ERROR) | (2 << 8)
ERROR_SNAPSHOT ExtendedErrorCode = xErrorCode(ERROR) | (3 << 8)
IOERR_READ ExtendedErrorCode = xErrorCode(IOERR) | (1 << 8)
IOERR_SHORT_READ ExtendedErrorCode = xErrorCode(IOERR) | (2 << 8)
IOERR_WRITE ExtendedErrorCode = xErrorCode(IOERR) | (3 << 8)
IOERR_FSYNC ExtendedErrorCode = xErrorCode(IOERR) | (4 << 8)
IOERR_DIR_FSYNC ExtendedErrorCode = xErrorCode(IOERR) | (5 << 8)
IOERR_TRUNCATE ExtendedErrorCode = xErrorCode(IOERR) | (6 << 8)
IOERR_FSTAT ExtendedErrorCode = xErrorCode(IOERR) | (7 << 8)
IOERR_UNLOCK ExtendedErrorCode = xErrorCode(IOERR) | (8 << 8)
IOERR_RDLOCK ExtendedErrorCode = xErrorCode(IOERR) | (9 << 8)
IOERR_DELETE ExtendedErrorCode = xErrorCode(IOERR) | (10 << 8)
IOERR_BLOCKED ExtendedErrorCode = xErrorCode(IOERR) | (11 << 8)
IOERR_NOMEM ExtendedErrorCode = xErrorCode(IOERR) | (12 << 8)
IOERR_ACCESS ExtendedErrorCode = xErrorCode(IOERR) | (13 << 8)
IOERR_CHECKRESERVEDLOCK ExtendedErrorCode = xErrorCode(IOERR) | (14 << 8)
IOERR_LOCK ExtendedErrorCode = xErrorCode(IOERR) | (15 << 8)
IOERR_CLOSE ExtendedErrorCode = xErrorCode(IOERR) | (16 << 8)
IOERR_DIR_CLOSE ExtendedErrorCode = xErrorCode(IOERR) | (17 << 8)
IOERR_SHMOPEN ExtendedErrorCode = xErrorCode(IOERR) | (18 << 8)
IOERR_SHMSIZE ExtendedErrorCode = xErrorCode(IOERR) | (19 << 8)
IOERR_SHMLOCK ExtendedErrorCode = xErrorCode(IOERR) | (20 << 8)
IOERR_SHMMAP ExtendedErrorCode = xErrorCode(IOERR) | (21 << 8)
IOERR_SEEK ExtendedErrorCode = xErrorCode(IOERR) | (22 << 8)
IOERR_DELETE_NOENT ExtendedErrorCode = xErrorCode(IOERR) | (23 << 8)
IOERR_MMAP ExtendedErrorCode = xErrorCode(IOERR) | (24 << 8)
IOERR_GETTEMPPATH ExtendedErrorCode = xErrorCode(IOERR) | (25 << 8)
IOERR_CONVPATH ExtendedErrorCode = xErrorCode(IOERR) | (26 << 8)
IOERR_VNODE ExtendedErrorCode = xErrorCode(IOERR) | (27 << 8)
IOERR_AUTH ExtendedErrorCode = xErrorCode(IOERR) | (28 << 8)
IOERR_BEGIN_ATOMIC ExtendedErrorCode = xErrorCode(IOERR) | (29 << 8)
IOERR_COMMIT_ATOMIC ExtendedErrorCode = xErrorCode(IOERR) | (30 << 8)
IOERR_ROLLBACK_ATOMIC ExtendedErrorCode = xErrorCode(IOERR) | (31 << 8)
IOERR_DATA ExtendedErrorCode = xErrorCode(IOERR) | (32 << 8)
IOERR_CORRUPTFS ExtendedErrorCode = xErrorCode(IOERR) | (33 << 8)
IOERR_IN_PAGE ExtendedErrorCode = xErrorCode(IOERR) | (34 << 8)
LOCKED_SHAREDCACHE ExtendedErrorCode = xErrorCode(LOCKED) | (1 << 8)
LOCKED_VTAB ExtendedErrorCode = xErrorCode(LOCKED) | (2 << 8)
BUSY_RECOVERY ExtendedErrorCode = xErrorCode(BUSY) | (1 << 8)
BUSY_SNAPSHOT ExtendedErrorCode = xErrorCode(BUSY) | (2 << 8)
BUSY_TIMEOUT ExtendedErrorCode = xErrorCode(BUSY) | (3 << 8)
CANTOPEN_NOTEMPDIR ExtendedErrorCode = xErrorCode(CANTOPEN) | (1 << 8)
CANTOPEN_ISDIR ExtendedErrorCode = xErrorCode(CANTOPEN) | (2 << 8)
CANTOPEN_FULLPATH ExtendedErrorCode = xErrorCode(CANTOPEN) | (3 << 8)
CANTOPEN_CONVPATH ExtendedErrorCode = xErrorCode(CANTOPEN) | (4 << 8)
// CANTOPEN_DIRTYWAL ExtendedErrorCode = xErrorCode(CANTOPEN) | (5 << 8) /* Not Used */
CANTOPEN_SYMLINK ExtendedErrorCode = xErrorCode(CANTOPEN) | (6 << 8)
CORRUPT_VTAB ExtendedErrorCode = xErrorCode(CORRUPT) | (1 << 8)
CORRUPT_SEQUENCE ExtendedErrorCode = xErrorCode(CORRUPT) | (2 << 8)
CORRUPT_INDEX ExtendedErrorCode = xErrorCode(CORRUPT) | (3 << 8)
READONLY_RECOVERY ExtendedErrorCode = xErrorCode(READONLY) | (1 << 8)
READONLY_CANTLOCK ExtendedErrorCode = xErrorCode(READONLY) | (2 << 8)
READONLY_ROLLBACK ExtendedErrorCode = xErrorCode(READONLY) | (3 << 8)
READONLY_DBMOVED ExtendedErrorCode = xErrorCode(READONLY) | (4 << 8)
READONLY_CANTINIT ExtendedErrorCode = xErrorCode(READONLY) | (5 << 8)
READONLY_DIRECTORY ExtendedErrorCode = xErrorCode(READONLY) | (6 << 8)
ABORT_ROLLBACK ExtendedErrorCode = xErrorCode(ABORT) | (2 << 8)
CONSTRAINT_CHECK ExtendedErrorCode = xErrorCode(CONSTRAINT) | (1 << 8)
CONSTRAINT_COMMITHOOK ExtendedErrorCode = xErrorCode(CONSTRAINT) | (2 << 8)
CONSTRAINT_FOREIGNKEY ExtendedErrorCode = xErrorCode(CONSTRAINT) | (3 << 8)
CONSTRAINT_FUNCTION ExtendedErrorCode = xErrorCode(CONSTRAINT) | (4 << 8)
CONSTRAINT_NOTNULL ExtendedErrorCode = xErrorCode(CONSTRAINT) | (5 << 8)
CONSTRAINT_PRIMARYKEY ExtendedErrorCode = xErrorCode(CONSTRAINT) | (6 << 8)
CONSTRAINT_TRIGGER ExtendedErrorCode = xErrorCode(CONSTRAINT) | (7 << 8)
CONSTRAINT_UNIQUE ExtendedErrorCode = xErrorCode(CONSTRAINT) | (8 << 8)
CONSTRAINT_VTAB ExtendedErrorCode = xErrorCode(CONSTRAINT) | (9 << 8)
CONSTRAINT_ROWID ExtendedErrorCode = xErrorCode(CONSTRAINT) | (10 << 8)
CONSTRAINT_PINNED ExtendedErrorCode = xErrorCode(CONSTRAINT) | (11 << 8)
CONSTRAINT_DATATYPE ExtendedErrorCode = xErrorCode(CONSTRAINT) | (12 << 8)
NOTICE_RECOVER_WAL ExtendedErrorCode = xErrorCode(NOTICE) | (1 << 8)
NOTICE_RECOVER_ROLLBACK ExtendedErrorCode = xErrorCode(NOTICE) | (2 << 8)
NOTICE_RBU ExtendedErrorCode = xErrorCode(NOTICE) | (3 << 8)
WARNING_AUTOINDEX ExtendedErrorCode = xErrorCode(WARNING) | (1 << 8)
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
)
// OpenFlag is a flag for the [OpenFlags] function.
//
// https://sqlite.org/c3ref/c_open_autoproxy.html
type OpenFlag uint32
const (
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://sqlite.org/c3ref/c_prepare_normalize.html
type PrepareFlag uint32
const (
PREPARE_PERSISTENT PrepareFlag = 0x01
PREPARE_NORMALIZE PrepareFlag = 0x02
PREPARE_NO_VTAB PrepareFlag = 0x04
PREPARE_DONT_LOG PrepareFlag = 0x10
)
// FunctionFlag is a flag that can be passed to
// [Conn.CreateFunction] and [Conn.CreateWindowFunction].
//
// https://sqlite.org/c3ref/c_deterministic.html
type FunctionFlag uint32
const (
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
INNOCUOUS FunctionFlag = 0x000200000
SELFORDER1 FunctionFlag = 0x002000000
// SUBTYPE FunctionFlag = 0x000100000
// RESULT_SUBTYPE FunctionFlag = 0x001000000
)
// StmtStatus are the available counters that can be queried with the [Stmt.Status] method.
//
// https://sqlite.org/c3ref/c_stmtstatus_counter.html
type StmtStatus uint32
const (
STMTSTATUS_FULLSCAN_STEP StmtStatus = 1
STMTSTATUS_SORT StmtStatus = 2
STMTSTATUS_AUTOINDEX StmtStatus = 3
STMTSTATUS_VM_STEP StmtStatus = 4
STMTSTATUS_REPREPARE StmtStatus = 5
STMTSTATUS_RUN StmtStatus = 6
STMTSTATUS_FILTER_MISS StmtStatus = 7
STMTSTATUS_FILTER_HIT StmtStatus = 8
STMTSTATUS_MEMUSED StmtStatus = 99
)
// DBStatus are the available "verbs" that can be passed to the [Conn.Status] method.
//
// https://sqlite.org/c3ref/c_dbstatus_options.html
type DBStatus uint32
const (
DBSTATUS_LOOKASIDE_USED DBStatus = 0
DBSTATUS_CACHE_USED DBStatus = 1
DBSTATUS_SCHEMA_USED DBStatus = 2
DBSTATUS_STMT_USED DBStatus = 3
DBSTATUS_LOOKASIDE_HIT DBStatus = 4
DBSTATUS_LOOKASIDE_MISS_SIZE DBStatus = 5
DBSTATUS_LOOKASIDE_MISS_FULL DBStatus = 6
DBSTATUS_CACHE_HIT DBStatus = 7
DBSTATUS_CACHE_MISS DBStatus = 8
DBSTATUS_CACHE_WRITE DBStatus = 9
DBSTATUS_DEFERRED_FKS DBStatus = 10
DBSTATUS_CACHE_USED_SHARED DBStatus = 11
DBSTATUS_CACHE_SPILL DBStatus = 12
// DBSTATUS_MAX DBStatus = 12
)
// DBConfig are the available database connection configuration options.
//
// https://sqlite.org/c3ref/c_dbconfig_defensive.html
type DBConfig uint32
const (
// DBCONFIG_MAINDBNAME DBConfig = 1000
// DBCONFIG_LOOKASIDE DBConfig = 1001
DBCONFIG_ENABLE_FKEY DBConfig = 1002
DBCONFIG_ENABLE_TRIGGER DBConfig = 1003
DBCONFIG_ENABLE_FTS3_TOKENIZER DBConfig = 1004
DBCONFIG_ENABLE_LOAD_EXTENSION DBConfig = 1005
DBCONFIG_NO_CKPT_ON_CLOSE DBConfig = 1006
DBCONFIG_ENABLE_QPSG DBConfig = 1007
DBCONFIG_TRIGGER_EQP DBConfig = 1008
DBCONFIG_RESET_DATABASE DBConfig = 1009
DBCONFIG_DEFENSIVE DBConfig = 1010
DBCONFIG_WRITABLE_SCHEMA DBConfig = 1011
DBCONFIG_LEGACY_ALTER_TABLE DBConfig = 1012
DBCONFIG_DQS_DML DBConfig = 1013
DBCONFIG_DQS_DDL DBConfig = 1014
DBCONFIG_ENABLE_VIEW DBConfig = 1015
DBCONFIG_LEGACY_FILE_FORMAT DBConfig = 1016
DBCONFIG_TRUSTED_SCHEMA DBConfig = 1017
DBCONFIG_STMT_SCANSTATUS DBConfig = 1018
DBCONFIG_REVERSE_SCANORDER DBConfig = 1019
DBCONFIG_ENABLE_ATTACH_CREATE DBConfig = 1020
DBCONFIG_ENABLE_ATTACH_WRITE DBConfig = 1021
DBCONFIG_ENABLE_COMMENTS DBConfig = 1022
// DBCONFIG_MAX DBConfig = 1022
)
// FcntlOpcode are the available opcodes for [Conn.FileControl].
//
// https://sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FcntlOpcode uint32
const (
FCNTL_LOCKSTATE FcntlOpcode = 1
FCNTL_CHUNK_SIZE FcntlOpcode = 6
FCNTL_FILE_POINTER FcntlOpcode = 7
FCNTL_PERSIST_WAL FcntlOpcode = 10
FCNTL_POWERSAFE_OVERWRITE FcntlOpcode = 13
FCNTL_VFS_POINTER FcntlOpcode = 27
FCNTL_JOURNAL_POINTER FcntlOpcode = 28
FCNTL_DATA_VERSION FcntlOpcode = 35
FCNTL_RESERVE_BYTES FcntlOpcode = 38
FCNTL_RESET_CACHE FcntlOpcode = 42
)
// LimitCategory are the available run-time limit categories.
//
// https://sqlite.org/c3ref/c_limit_attached.html
type LimitCategory uint32
const (
LIMIT_LENGTH LimitCategory = 0
LIMIT_SQL_LENGTH LimitCategory = 1
LIMIT_COLUMN LimitCategory = 2
LIMIT_EXPR_DEPTH LimitCategory = 3
LIMIT_COMPOUND_SELECT LimitCategory = 4
LIMIT_VDBE_OP LimitCategory = 5
LIMIT_FUNCTION_ARG LimitCategory = 6
LIMIT_ATTACHED LimitCategory = 7
LIMIT_LIKE_PATTERN_LENGTH LimitCategory = 8
LIMIT_VARIABLE_NUMBER LimitCategory = 9
LIMIT_TRIGGER_DEPTH LimitCategory = 10
LIMIT_WORKER_THREADS LimitCategory = 11
)
// AuthorizerActionCode are the integer action codes
// that the authorizer callback may be passed.
//
// https://sqlite.org/c3ref/c_alter_table.html
type AuthorizerActionCode uint32
const (
/***************************************************** 3rd ************ 4th ***********/
AUTH_CREATE_INDEX AuthorizerActionCode = 1 /* Index Name Table Name */
AUTH_CREATE_TABLE AuthorizerActionCode = 2 /* Table Name NULL */
AUTH_CREATE_TEMP_INDEX AuthorizerActionCode = 3 /* Index Name Table Name */
AUTH_CREATE_TEMP_TABLE AuthorizerActionCode = 4 /* Table Name NULL */
AUTH_CREATE_TEMP_TRIGGER AuthorizerActionCode = 5 /* Trigger Name Table Name */
AUTH_CREATE_TEMP_VIEW AuthorizerActionCode = 6 /* View Name NULL */
AUTH_CREATE_TRIGGER AuthorizerActionCode = 7 /* Trigger Name Table Name */
AUTH_CREATE_VIEW AuthorizerActionCode = 8 /* View Name NULL */
AUTH_DELETE AuthorizerActionCode = 9 /* Table Name NULL */
AUTH_DROP_INDEX AuthorizerActionCode = 10 /* Index Name Table Name */
AUTH_DROP_TABLE AuthorizerActionCode = 11 /* Table Name NULL */
AUTH_DROP_TEMP_INDEX AuthorizerActionCode = 12 /* Index Name Table Name */
AUTH_DROP_TEMP_TABLE AuthorizerActionCode = 13 /* Table Name NULL */
AUTH_DROP_TEMP_TRIGGER AuthorizerActionCode = 14 /* Trigger Name Table Name */
AUTH_DROP_TEMP_VIEW AuthorizerActionCode = 15 /* View Name NULL */
AUTH_DROP_TRIGGER AuthorizerActionCode = 16 /* Trigger Name Table Name */
AUTH_DROP_VIEW AuthorizerActionCode = 17 /* View Name NULL */
AUTH_INSERT AuthorizerActionCode = 18 /* Table Name NULL */
AUTH_PRAGMA AuthorizerActionCode = 19 /* Pragma Name 1st arg or NULL */
AUTH_READ AuthorizerActionCode = 20 /* Table Name Column Name */
AUTH_SELECT AuthorizerActionCode = 21 /* NULL NULL */
AUTH_TRANSACTION AuthorizerActionCode = 22 /* Operation NULL */
AUTH_UPDATE AuthorizerActionCode = 23 /* Table Name Column Name */
AUTH_ATTACH AuthorizerActionCode = 24 /* Filename NULL */
AUTH_DETACH AuthorizerActionCode = 25 /* Database Name NULL */
AUTH_ALTER_TABLE AuthorizerActionCode = 26 /* Database Name Table Name */
AUTH_REINDEX AuthorizerActionCode = 27 /* Index Name NULL */
AUTH_ANALYZE AuthorizerActionCode = 28 /* Table Name NULL */
AUTH_CREATE_VTABLE AuthorizerActionCode = 29 /* Table Name Module Name */
AUTH_DROP_VTABLE AuthorizerActionCode = 30 /* Table Name Module Name */
AUTH_FUNCTION AuthorizerActionCode = 31 /* NULL Function Name */
AUTH_SAVEPOINT AuthorizerActionCode = 32 /* Operation Savepoint Name */
AUTH_RECURSIVE AuthorizerActionCode = 33 /* NULL NULL */
// AUTH_COPY AuthorizerActionCode = 0 /* No longer used */
)
// AuthorizerReturnCode are the integer codes
// that the authorizer callback may return.
//
// https://sqlite.org/c3ref/c_deny.html
type AuthorizerReturnCode uint32
const (
AUTH_OK AuthorizerReturnCode = 0
AUTH_DENY AuthorizerReturnCode = 1 /* Abort the SQL statement with an error */
AUTH_IGNORE AuthorizerReturnCode = 2 /* Don't allow access, but don't generate an error */
)
// CheckpointMode are all the checkpoint mode values.
//
// https://sqlite.org/c3ref/c_checkpoint_full.html
type CheckpointMode uint32
const (
CHECKPOINT_PASSIVE CheckpointMode = 0 /* Do as much as possible w/o blocking */
CHECKPOINT_FULL CheckpointMode = 1 /* Wait for writers, then checkpoint */
CHECKPOINT_RESTART CheckpointMode = 2 /* Like FULL but wait for readers */
CHECKPOINT_TRUNCATE CheckpointMode = 3 /* Like RESTART but also truncate WAL */
)
// TxnState are the allowed return values from [Conn.TxnState].
//
// https://sqlite.org/c3ref/c_txn_none.html
type TxnState uint32
const (
TXN_NONE TxnState = 0
TXN_READ TxnState = 1
TXN_WRITE TxnState = 2
)
// TraceEvent identifies the classes of events that can be monitored with [Conn.Trace].
//
// https://sqlite.org/c3ref/c_trace.html
type TraceEvent uint32
const (
TRACE_STMT TraceEvent = 0x01
TRACE_PROFILE TraceEvent = 0x02
TRACE_ROW TraceEvent = 0x04
TRACE_CLOSE TraceEvent = 0x08
)
// Datatype is a fundamental datatype of SQLite.
//
// https://sqlite.org/c3ref/c_blob.html
type Datatype uint32
const (
INTEGER Datatype = 1
FLOAT Datatype = 2
TEXT Datatype = 3
BLOB Datatype = 4
NULL Datatype = 5
)
// String implements the [fmt.Stringer] interface.
func (t Datatype) String() string {
const name = "INTEGERFLOATEXTBLOBNULL"
switch t {
case INTEGER:
return name[0:7]
case FLOAT:
return name[7:12]
case TEXT:
return name[11:15]
case BLOB:
return name[15:19]
case NULL:
return name[19:23]
}
return strconv.FormatUint(uint64(t), 10)
}
package sqlite3
import (
"encoding/json"
"errors"
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Context is the context in which an SQL function executes.
// An SQLite [Context] is in no way related to a Go [context.Context].
//
// https://sqlite.org/c3ref/context.html
type Context struct {
c *Conn
handle ptr_t
}
// Conn returns the database connection of the
// [Conn.CreateFunction] or [Conn.CreateWindowFunction]
// routines that originally registered the application defined function.
//
// https://sqlite.org/c3ref/context_db_handle.html
func (ctx Context) Conn() *Conn {
return ctx.c
}
// SetAuxData saves metadata for argument n of the function.
//
// https://sqlite.org/c3ref/get_auxdata.html
func (ctx Context) SetAuxData(n int, data any) {
ptr := util.AddHandle(ctx.c.ctx, data)
ctx.c.call("sqlite3_set_auxdata_go", stk_t(ctx.handle), stk_t(n), stk_t(ptr))
}
// GetAuxData returns metadata for argument n of the function.
//
// https://sqlite.org/c3ref/get_auxdata.html
func (ctx Context) GetAuxData(n int) any {
ptr := ptr_t(ctx.c.call("sqlite3_get_auxdata", stk_t(ctx.handle), stk_t(n)))
return util.GetHandle(ctx.c.ctx, ptr)
}
// ResultBool sets the result of the function to a bool.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultBool(value bool) {
var i int64
if value {
i = 1
}
ctx.ResultInt64(i)
}
// ResultInt sets the result of the function to an int.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultInt(value int) {
ctx.ResultInt64(int64(value))
}
// ResultInt64 sets the result of the function to an int64.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultInt64(value int64) {
ctx.c.call("sqlite3_result_int64",
stk_t(ctx.handle), stk_t(value))
}
// ResultFloat sets the result of the function to a float64.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultFloat(value float64) {
ctx.c.call("sqlite3_result_double",
stk_t(ctx.handle), stk_t(math.Float64bits(value)))
}
// ResultText sets the result of the function to a string.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultText(value string) {
ptr := ctx.c.newString(value)
ctx.c.call("sqlite3_result_text_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
}
// ResultRawText sets the text result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultRawText(value []byte) {
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_text_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
}
// ResultBlob sets the result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultBlob(value []byte) {
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_blob_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
}
// ResultZeroBlob sets the result of the function to a zero-filled, length n BLOB.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultZeroBlob(n int64) {
ctx.c.call("sqlite3_result_zeroblob64",
stk_t(ctx.handle), stk_t(n))
}
// ResultNull sets the result of the function to NULL.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultNull() {
ctx.c.call("sqlite3_result_null",
stk_t(ctx.handle))
}
// ResultTime sets the result of the function to a [time.Time].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultTime(value time.Time, format TimeFormat) {
switch format {
case TimeFormatDefault, TimeFormatAuto, time.RFC3339Nano:
ctx.resultRFC3339Nano(value)
return
}
switch v := format.Encode(value).(type) {
case string:
ctx.ResultText(v)
case int64:
ctx.ResultInt64(v)
case float64:
ctx.ResultFloat(v)
default:
panic(util.AssertErr())
}
}
func (ctx Context) resultRFC3339Nano(value time.Time) {
const maxlen = int64(len(time.RFC3339Nano)) + 5
ptr := ctx.c.new(maxlen)
buf := util.View(ctx.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
ctx.c.call("sqlite3_result_text_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(buf)))
}
// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
// except that it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultPointer(ptr any) {
valPtr := util.AddHandle(ctx.c.ctx, ptr)
ctx.c.call("sqlite3_result_pointer_go",
stk_t(ctx.handle), stk_t(valPtr))
}
// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultJSON(value any) {
data, err := json.Marshal(value)
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultRawText(data)
}
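// A usage sketch (illustrative): an application-defined scalar function that
// returns a JSON result. Registration is via Conn.CreateFunction, defined
// elsewhere in this package; its exact signature is assumed here.
//
//	db.CreateFunction("app_info", 0, sqlite3.DETERMINISTIC,
//		func(ctx sqlite3.Context, arg ...sqlite3.Value) {
//			ctx.ResultJSON(map[string]any{"name": "demo", "version": 1})
//		})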
// ResultValue sets the result of the function to a copy of [Value].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultValue(value Value) {
if value.c != ctx.c {
ctx.ResultError(MISUSE)
return
}
ctx.c.call("sqlite3_result_value",
stk_t(ctx.handle), stk_t(value.handle))
}
// ResultError sets the result of the function to an error.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultError(err error) {
if errors.Is(err, NOMEM) {
ctx.c.call("sqlite3_result_error_nomem", stk_t(ctx.handle))
return
}
if errors.Is(err, TOOBIG) {
ctx.c.call("sqlite3_result_error_toobig", stk_t(ctx.handle))
return
}
msg, code := errorCode(err, _OK)
if msg != "" {
defer ctx.c.arena.mark()()
ptr := ctx.c.arena.string(msg)
ctx.c.call("sqlite3_result_error",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(msg)))
}
if code != _OK {
ctx.c.call("sqlite3_result_error_code",
stk_t(ctx.handle), stk_t(code))
}
}
// VTabNoChange may return true if a column is being fetched as part
// of an update during which the column value will not change.
//
// https://sqlite.org/c3ref/vtab_nochange.html
func (ctx Context) VTabNoChange() bool {
b := int32(ctx.c.call("sqlite3_vtab_nochange", stk_t(ctx.handle)))
return b != 0
}
// Package driver provides a database/sql driver for SQLite.
//
// Importing package driver registers a [database/sql] driver named "sqlite3".
// You may also need to import package embed.
//
// import _ "github.com/ncruces/go-sqlite3/driver"
// import _ "github.com/ncruces/go-sqlite3/embed"
//
// The data source name for "sqlite3" databases can be a filename or a "file:" [URI].
//
// # Default transaction mode
//
// The [TRANSACTION] mode can be specified using "_txlock":
//
// sql.Open("sqlite3", "file:demo.db?_txlock=immediate")
//
// Possible values are: "deferred" (the default), "immediate", "exclusive".
// Regardless of "_txlock":
// - a [linearizable] transaction is always "exclusive";
// - a [serializable] transaction is always "immediate";
// - a [read-only] transaction is always "deferred".
//
// # Working with time
//
// The time encoding/decoding format can be specified using "_timefmt":
//
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
//
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
// - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
// - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
// - "rfc3339" encodes and decodes RFC 3339 only.
//
// If you encode as RFC 3339 (the default),
// consider using the TIME [collating sequence] to produce a time-ordered sequence.
//
// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful.
// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding.
//
// When using a custom time struct, you'll have to implement
// [database/sql/driver.Valuer] and [database/sql.Scanner].
//
// The Value method should ideally encode to a time [format] supported by SQLite.
// This ensures SQL date and time functions work as they should,
// and that your schema works with other SQLite tools.
// [sqlite3.TimeFormat.Encode] may help.
//
// The Scan method needs to take into account that the value it receives can be of differing types.
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
// Or it can be a: string, int64, float64, []byte, or nil,
// depending on the column type and on how the value was written.
// [sqlite3.TimeFormat.Decode] may help.
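//
// A minimal sketch of such a type (MyTime is a hypothetical name;
// it assumes the sqlite3 and database/sql/driver imports):
//
//	type MyTime struct{ time.Time }
//
//	func (t MyTime) Value() (driver.Value, error) {
//		return sqlite3.TimeFormatDefault.Encode(t.Time), nil
//	}
//
//	func (t *MyTime) Scan(v any) error {
//		tm, err := sqlite3.TimeFormatAuto.Decode(v)
//		t.Time = tm
//		return err
//	}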
//
// # Setting PRAGMAs
//
// [PRAGMA] statements can be specified using "_pragma":
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
//
// If no PRAGMAs are specified, a busy timeout of 1 minute is set.
//
// Order matters:
// encryption keys, busy timeout and locking mode should be the first PRAGMAs set,
// in that order.
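//
// For example, an illustrative DSN setting several PRAGMAs
// (pick PRAGMAs appropriate to your build):
//
//	sql.Open("sqlite3", "file:demo.db"+
//		"?_pragma=busy_timeout(10000)"+
//		"&_pragma=journal_mode(WAL)")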
//
// [URI]: https://sqlite.org/uri.html
// [PRAGMA]: https://sqlite.org/pragma.html
// [TRANSACTION]: https://sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
// [linearizable]: https://pkg.go.dev/database/sql#TxOptions
// [serializable]: https://pkg.go.dev/database/sql#TxOptions
// [read-only]: https://pkg.go.dev/database/sql#TxOptions
// [format]: https://sqlite.org/lang_datefunc.html#time_values
// [collating sequence]: https://sqlite.org/datatype3.html#collating_sequences
package driver
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"net/url"
"reflect"
"strings"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3/driver.driverName=sqlite"
var driverName = "sqlite3"
func init() {
if driverName != "" {
sql.Register(driverName, &SQLite{})
}
}
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
//
// Open accepts zero, one, or two callbacks (nil callbacks are ignored).
// The first callback is called when the driver opens a new connection.
// The second callback is called before the driver closes a connection.
// The [sqlite3.Conn] can be used to execute queries, register functions, etc.
func Open(dataSourceName string, fn ...func(*sqlite3.Conn) error) (*sql.DB, error) {
if len(fn) > 2 {
return nil, sqlite3.MISUSE
}
var init, term func(*sqlite3.Conn) error
if len(fn) > 1 {
term = fn[1]
}
if len(fn) > 0 {
init = fn[0]
}
c, err := newConnector(dataSourceName, init, term)
if err != nil {
return nil, err
}
return sql.OpenDB(c), nil
}
// SQLite implements [database/sql/driver.Driver].
type SQLite struct{}
var (
// Ensure these interfaces are implemented:
_ driver.DriverContext = &SQLite{}
)
// Open implements [database/sql/driver.Driver].
func (d *SQLite) Open(name string) (driver.Conn, error) {
c, err := newConnector(name, nil, nil)
if err != nil {
return nil, err
}
return c.Connect(context.Background())
}
// OpenConnector implements [database/sql/driver.DriverContext].
func (d *SQLite) OpenConnector(name string) (driver.Connector, error) {
return newConnector(name, nil, nil)
}
func newConnector(name string, init, term func(*sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init, term: term}
var txlock, timefmt string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, err := url.ParseQuery(after)
if err != nil {
return nil, err
}
txlock = query.Get("_txlock")
timefmt = query.Get("_timefmt")
c.pragmas = query.Has("_pragma")
}
}
switch txlock {
case "", "deferred", "concurrent", "immediate", "exclusive":
c.txLock = txlock
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", txlock)
}
switch timefmt {
case "":
c.tmRead = sqlite3.TimeFormatAuto
c.tmWrite = sqlite3.TimeFormatDefault
case "sqlite":
c.tmRead = sqlite3.TimeFormatAuto
c.tmWrite = sqlite3.TimeFormat3
case "rfc3339":
c.tmRead = sqlite3.TimeFormatDefault
c.tmWrite = sqlite3.TimeFormatDefault
default:
c.tmRead = sqlite3.TimeFormat(timefmt)
c.tmWrite = sqlite3.TimeFormat(timefmt)
}
return &c, nil
}
type connector struct {
init func(*sqlite3.Conn) error
term func(*sqlite3.Conn) error
name string
txLock string
tmRead sqlite3.TimeFormat
tmWrite sqlite3.TimeFormat
pragmas bool
}
func (n *connector) Driver() driver.Driver {
return &SQLite{}
}
func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) {
c := &conn{
txLock: n.txLock,
tmRead: n.tmRead,
tmWrite: n.tmWrite,
}
c.Conn, err = sqlite3.OpenContext(ctx, n.name)
if err != nil {
return nil, err
}
defer func() {
if ret == nil {
c.Close()
}
}()
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
if !n.pragmas {
err = c.Conn.BusyTimeout(time.Minute)
if err != nil {
return nil, err
}
}
if n.init != nil {
err = n.init(c.Conn)
if err != nil {
return nil, err
}
}
if n.pragmas || n.init != nil {
s, _, err := c.Conn.Prepare(`PRAGMA query_only`)
if err != nil {
return nil, err
}
defer s.Close()
if s.Step() && s.ColumnBool(0) {
c.readOnly = '1'
} else {
c.readOnly = '0'
}
err = s.Close()
if err != nil {
return nil, err
}
}
if n.term != nil {
err = c.Conn.Trace(sqlite3.TRACE_CLOSE, func(sqlite3.TraceEvent, any, any) error {
return n.term(c.Conn)
})
if err != nil {
return nil, err
}
}
return c, nil
}
// Conn is implemented by the SQLite [database/sql] driver connections.
//
// It can be used to access SQLite features like [online backup]:
//
// db, err := driver.Open("temp.db")
// if err != nil {
// log.Fatal(err)
// }
// defer db.Close()
//
// conn, err := db.Conn(context.TODO())
// if err != nil {
// log.Fatal(err)
// }
// defer conn.Close()
//
// err = conn.Raw(func(driverConn any) error {
// conn := driverConn.(driver.Conn)
// return conn.Raw().Backup("main", "backup.db")
// })
// if err != nil {
// log.Fatal(err)
// }
//
// [online backup]: https://sqlite.org/backup.html
type Conn interface {
Raw() *sqlite3.Conn
driver.Conn
driver.ConnBeginTx
driver.ConnPrepareContext
}
type conn struct {
*sqlite3.Conn
txLock string
txReset string
tmRead sqlite3.TimeFormat
tmWrite sqlite3.TimeFormat
readOnly byte
}
var (
// Ensure these interfaces are implemented:
_ Conn = &conn{}
_ driver.ExecerContext = &conn{}
)
func (c *conn) Raw() *sqlite3.Conn {
return c.Conn
}
// Deprecated: use BeginTx instead.
func (c *conn) Begin() (driver.Tx, error) {
// notest
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
var txLock string
switch opts.Isolation {
default:
return nil, util.IsolationErr
case driver.IsolationLevel(sql.LevelLinearizable):
txLock = "exclusive"
case driver.IsolationLevel(sql.LevelSerializable):
txLock = "immediate"
case driver.IsolationLevel(sql.LevelDefault):
if !opts.ReadOnly {
txLock = c.txLock
}
}
c.txReset = ``
txBegin := `BEGIN ` + txLock
if opts.ReadOnly {
txBegin += ` ; PRAGMA query_only=on`
c.txReset = `; PRAGMA query_only=` + string(c.readOnly)
}
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
err := c.Conn.Exec(txBegin)
if err != nil {
return nil, err
}
return c, nil
}
func (c *conn) Commit() error {
err := c.Conn.Exec(`COMMIT` + c.txReset)
if err != nil && !c.Conn.GetAutocommit() {
c.Rollback()
}
return err
}
func (c *conn) Rollback() error {
// ROLLBACK even if interrupted.
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
return c.Conn.Exec(`ROLLBACK` + c.txReset)
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
// notest
return c.PrepareContext(context.Background(), query)
}
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
s, tail, err := c.Conn.Prepare(query)
if err != nil {
return nil, err
}
if notWhitespace(tail) {
s.Close()
return nil, util.TailErr
}
return &stmt{Stmt: s, tmRead: c.tmRead, tmWrite: c.tmWrite, inputs: -2}, nil
}
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
// Slow path.
return nil, driver.ErrSkip
}
if savept, ok := ctx.(*saveptCtx); ok {
// Called from driver.Savepoint.
savept.Savepoint = c.Conn.Savepoint()
return resultRowsAffected(0), nil
}
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
err := c.Conn.Exec(query)
if err != nil {
return nil, err
}
return newResult(c.Conn), nil
}
func (c *conn) CheckNamedValue(arg *driver.NamedValue) error {
// Fast path: short circuit argument verification.
// Arguments will be rejected by conn.ExecContext.
return nil
}
type stmt struct {
*sqlite3.Stmt
tmWrite sqlite3.TimeFormat
tmRead sqlite3.TimeFormat
inputs int
}
var (
// Ensure these interfaces are implemented:
_ driver.StmtExecContext = &stmt{}
_ driver.StmtQueryContext = &stmt{}
_ driver.NamedValueChecker = &stmt{}
)
func (s *stmt) NumInput() int {
if s.inputs >= -1 {
return s.inputs
}
n := s.Stmt.BindCount()
for i := 1; i <= n; i++ {
if s.Stmt.BindName(i) != "" {
s.inputs = -1
return -1
}
}
s.inputs = n
return n
}
// Deprecated: use ExecContext instead.
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
// notest
return s.ExecContext(context.Background(), namedValues(args))
}
// Deprecated: use QueryContext instead.
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
// notest
return s.QueryContext(context.Background(), namedValues(args))
}
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
err := s.setupBindings(args)
if err != nil {
return nil, err
}
old := s.Stmt.Conn().SetInterrupt(ctx)
defer s.Stmt.Conn().SetInterrupt(old)
err = errors.Join(
s.Stmt.Exec(),
s.Stmt.ClearBindings())
if err != nil {
return nil, err
}
return newResult(s.Stmt.Conn()), nil
}
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.setupBindings(args)
if err != nil {
return nil, err
}
return &rows{ctx: ctx, stmt: s}, nil
}
func (s *stmt) setupBindings(args []driver.NamedValue) (err error) {
var ids [3]int
for _, arg := range args {
ids := ids[:0]
if arg.Name == "" {
ids = append(ids, arg.Ordinal)
} else {
for _, prefix := range [...]string{":", "@", "$"} {
if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 {
ids = append(ids, id)
}
}
}
for _, id := range ids {
switch a := arg.Value.(type) {
case bool:
err = s.Stmt.BindBool(id, a)
case int:
err = s.Stmt.BindInt(id, a)
case int64:
err = s.Stmt.BindInt64(id, a)
case float64:
err = s.Stmt.BindFloat(id, a)
case string:
err = s.Stmt.BindText(id, a)
case []byte:
err = s.Stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.Stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.Stmt.BindTime(id, a, s.tmWrite)
case util.JSON:
err = s.Stmt.BindJSON(id, a.Value)
case util.PointerUnwrap:
err = s.Stmt.BindPointer(id, util.UnwrapPointer(a))
case nil:
err = s.Stmt.BindNull(id)
default:
panic(util.AssertErr())
}
if err != nil {
return err
}
}
}
return nil
}
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
time.Time, sqlite3.ZeroBlob,
util.JSON, util.PointerUnwrap,
nil:
return nil
default:
return driver.ErrSkip
}
}
func newResult(c *sqlite3.Conn) driver.Result {
rows := c.Changes()
if rows != 0 {
id := c.LastInsertRowID()
if id != 0 {
return result{id, rows}
}
}
return resultRowsAffected(rows)
}
type result struct{ lastInsertId, rowsAffected int64 }
func (r result) LastInsertId() (int64, error) {
return r.lastInsertId, nil
}
func (r result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
type resultRowsAffected int64
func (r resultRowsAffected) LastInsertId() (int64, error) {
return 0, nil
}
func (r resultRowsAffected) RowsAffected() (int64, error) {
return int64(r), nil
}
type rows struct {
ctx context.Context
*stmt
names []string
types []string
nulls []bool
scans []scantype
}
type scantype byte
const (
_ANY scantype = iota
_INT scantype = scantype(sqlite3.INTEGER)
_REAL scantype = scantype(sqlite3.FLOAT)
_TEXT scantype = scantype(sqlite3.TEXT)
_BLOB scantype = scantype(sqlite3.BLOB)
_NULL scantype = scantype(sqlite3.NULL)
_BOOL scantype = iota
_TIME
)
var (
// Ensure these interfaces are implemented:
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
_ driver.RowsColumnTypeNullable = &rows{}
)
func (r *rows) Close() error {
return errors.Join(
r.Stmt.Reset(),
r.Stmt.ClearBindings())
}
func (r *rows) Columns() []string {
if r.names == nil {
count := r.Stmt.ColumnCount()
names := make([]string, count)
for i := range names {
names[i] = r.Stmt.ColumnName(i)
}
r.names = names
}
return r.names
}
func (r *rows) loadColumnMetadata() {
if r.nulls == nil {
count := r.Stmt.ColumnCount()
nulls := make([]bool, count)
types := make([]string, count)
scans := make([]scantype, count)
for i := range nulls {
if col := r.Stmt.ColumnOriginName(i); col != "" {
types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata(
r.Stmt.ColumnDatabaseName(i),
r.Stmt.ColumnTableName(i),
col)
types[i] = strings.ToUpper(types[i])
// These types are only used before we have rows,
// and otherwise as type hints.
// The first few ensure STRICT tables are strictly typed.
// The other two are type hints for booleans and time.
switch types[i] {
case "INT", "INTEGER":
scans[i] = _INT
case "REAL":
scans[i] = _REAL
case "TEXT":
scans[i] = _TEXT
case "BLOB":
scans[i] = _BLOB
case "BOOLEAN":
scans[i] = _BOOL
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
scans[i] = _TIME
}
}
}
r.nulls = nulls
r.types = types
r.scans = scans
}
}
func (r *rows) declType(index int) string {
if r.types == nil {
count := r.Stmt.ColumnCount()
types := make([]string, count)
for i := range types {
types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
}
r.types = types
}
return r.types[index]
}
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
r.loadColumnMetadata()
decltype := r.types[index]
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
decltype = decltype[:i]
}
}
return strings.TrimSpace(decltype)
}
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
r.loadColumnMetadata()
if r.nulls[index] {
return false, true
}
return true, false
}
func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
r.loadColumnMetadata()
scan := r.scans[index]
if r.Stmt.Busy() {
// SQLite is dynamically typed and we now have a row.
// Always use the type of the value itself,
// unless the scan type is more specific
// and can scan the actual value.
val := scantype(r.Stmt.ColumnType(index))
useValType := true
switch {
case scan == _TIME && val != _BLOB && val != _NULL:
t := r.Stmt.ColumnTime(index, r.tmRead)
useValType = t == time.Time{}
case scan == _BOOL && val == _INT:
i := r.Stmt.ColumnInt64(index)
useValType = i != 0 && i != 1
case scan == _BLOB && val == _NULL:
useValType = false
}
if useValType {
scan = val
}
}
switch scan {
case _INT:
return reflect.TypeFor[int64]()
case _REAL:
return reflect.TypeFor[float64]()
case _TEXT:
return reflect.TypeFor[string]()
case _BLOB:
return reflect.TypeFor[[]byte]()
case _BOOL:
return reflect.TypeFor[bool]()
case _TIME:
return reflect.TypeFor[time.Time]()
default:
return reflect.TypeFor[any]()
}
}
func (r *rows) Next(dest []driver.Value) error {
old := r.Stmt.Conn().SetInterrupt(r.ctx)
defer r.Stmt.Conn().SetInterrupt(old)
if !r.Stmt.Step() {
if err := r.Stmt.Err(); err != nil {
return err
}
return io.EOF
}
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
err := r.Stmt.Columns(data...)
for i := range dest {
if t, ok := r.decodeTime(i, dest[i]); ok {
dest[i] = t
}
}
return err
}
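// decodeTime attempts to interpret a scanned value as a [time.Time],
// based on the column's declared type and the configured time formats.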
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) {
switch v := v.(type) {
case int64, float64:
// could be a time value
case string:
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano {
break
}
t, ok := maybeTime(v)
if ok {
return t, true
}
default:
return
}
switch r.declType(i) {
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
// could be a time value
default:
return
}
t, err := r.tmRead.Decode(v)
return t, err == nil
}
package driver
import (
"database/sql"
"time"
"github.com/ncruces/go-sqlite3"
)
// Savepoint establishes a new transaction savepoint.
//
// https://sqlite.org/lang_savepoint.html
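//
// A usage sketch, assuming [sqlite3.Savepoint] offers
// Release and Rollback as in the main package:
//
//	tx, err := db.BeginTx(ctx, nil)
//	// ...
//	savept := driver.Savepoint(tx)
//	// ... work inside the savepoint ...
//	savept.Release(&err)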
func Savepoint(tx *sql.Tx) sqlite3.Savepoint {
var ctx saveptCtx
tx.ExecContext(&ctx, "")
return ctx.Savepoint
}
// A saveptCtx is never canceled, has no values, and has no deadline.
type saveptCtx struct{ sqlite3.Savepoint }
func (*saveptCtx) Deadline() (deadline time.Time, ok bool) {
// notest
return
}
func (*saveptCtx) Done() <-chan struct{} {
// notest
return nil
}
func (*saveptCtx) Err() error {
// notest
return nil
}
func (*saveptCtx) Value(key any) any {
// notest
return nil
}
package driver
import "time"
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text string) (_ time.Time, _ bool) {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
return
}
if len(text) > len(time.RFC3339Nano) {
return
}
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
return
}
// Slow path.
var buf [len(time.RFC3339Nano)]byte
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) {
return date, true
}
return
}
package driver
import "database/sql/driver"
func namedValues(args []driver.Value) []driver.NamedValue {
named := make([]driver.NamedValue, len(args))
for i, v := range args {
named[i] = driver.NamedValue{
Ordinal: i + 1,
Value: v,
}
}
return named
}
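// notWhitespace reports whether sql contains any token other than
// whitespace and SQL comments, up to the first NUL byte.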
func notWhitespace(sql string) bool {
const (
code = iota
slash
minus
ccomment
sqlcomment
endcomment
)
state := code
for _, b := range ([]byte)(sql) {
if b == 0 {
break
}
switch state {
case code:
switch b {
case '/':
state = slash
case '-':
state = minus
case ' ', ';', '\t', '\n', '\v', '\f', '\r':
continue
default:
return true
}
case slash:
if b != '*' {
return true
}
state = ccomment
case minus:
if b != '-' {
return true
}
state = sqlcomment
case ccomment:
if b == '*' {
state = endcomment
}
case sqlcomment:
if b == '\n' {
state = code
}
case endcomment:
switch b {
case '/':
state = code
case '*':
state = endcomment
default:
state = ccomment
}
}
}
return state == slash || state == minus
}
// Package embed embeds SQLite into your application.
//
// Importing package embed initializes the [sqlite3.Binary] variable
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
package embed
import (
_ "embed"
"unsafe"
"github.com/ncruces/go-sqlite3"
)
//go:embed sqlite3.wasm
var binary string
func init() {
sqlite3.Binary = unsafe.Slice(unsafe.StringData(binary), len(binary))
}
package sqlite3
import (
"errors"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Error wraps an SQLite Error Code.
//
// https://sqlite.org/c3ref/errcode.html
type Error struct {
msg string
sql string
code res_t
}
// Code returns the primary error code for this error.
//
// https://sqlite.org/rescode.html
func (e *Error) Code() ErrorCode {
return ErrorCode(e.code)
}
// ExtendedCode returns the extended error code for this error.
//
// https://sqlite.org/rescode.html
func (e *Error) ExtendedCode() ExtendedErrorCode {
return xErrorCode(e.code)
}
// Error implements the error interface.
func (e *Error) Error() string {
var b strings.Builder
b.WriteString(util.ErrorCodeString(uint32(e.code)))
if e.msg != "" {
b.WriteString(": ")
b.WriteString(e.msg)
}
return b.String()
}
// Is tests whether this error matches a given [ErrorCode] or [ExtendedErrorCode].
//
// It makes it possible to do:
//
// if errors.Is(err, sqlite3.BUSY) {
// // ... handle BUSY
// }
func (e *Error) Is(err error) bool {
switch c := err.(type) {
case ErrorCode:
return c == e.Code()
case ExtendedErrorCode:
return c == e.ExtendedCode()
}
return false
}
// As converts this error to an [ErrorCode] or [ExtendedErrorCode].
func (e *Error) As(err any) bool {
switch c := err.(type) {
case *ErrorCode:
*c = e.Code()
return true
case *ExtendedErrorCode:
*c = e.ExtendedCode()
return true
}
return false
}
// Temporary returns true for [BUSY] errors.
func (e *Error) Temporary() bool {
return e.Code() == BUSY
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
func (e *Error) Timeout() bool {
return e.ExtendedCode() == BUSY_TIMEOUT
}
// SQL returns the SQL starting at the token that triggered a syntax error.
func (e *Error) SQL() string {
return e.sql
}
// Error implements the error interface.
func (e ErrorCode) Error() string {
return util.ErrorCodeString(uint32(e))
}
// Temporary returns true for [BUSY] and [INTERRUPT] errors.
func (e ErrorCode) Temporary() bool {
return e == BUSY || e == INTERRUPT
}
// ExtendedCode returns the extended error code for this error.
func (e ErrorCode) ExtendedCode() ExtendedErrorCode {
return xErrorCode(e)
}
// Error implements the error interface.
func (e ExtendedErrorCode) Error() string {
return util.ErrorCodeString(uint32(e))
}
// Is tests whether this error matches a given [ErrorCode].
func (e ExtendedErrorCode) Is(err error) bool {
c, ok := err.(ErrorCode)
return ok && c == ErrorCode(e)
}
// As converts this error to an [ErrorCode].
func (e ExtendedErrorCode) As(err any) bool {
c, ok := err.(*ErrorCode)
if ok {
*c = ErrorCode(e)
}
return ok
}
// Temporary returns true for [BUSY] and [INTERRUPT] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY || ErrorCode(e) == INTERRUPT
}
// Timeout returns true for [BUSY_TIMEOUT] errors.
func (e ExtendedErrorCode) Timeout() bool {
return e == BUSY_TIMEOUT
}
// Code returns the primary error code for this error.
func (e ExtendedErrorCode) Code() ErrorCode {
return ErrorCode(e)
}
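// errorCode extracts an error message and result code from err,
// falling back to def when err carries no SQLite error code.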
func errorCode(err error, def ErrorCode) (msg string, code res_t) {
switch code := err.(type) {
case nil:
return "", _OK
case ErrorCode:
return "", res_t(code)
case xErrorCode:
return "", res_t(code)
case *Error:
return code.msg, res_t(code.code)
}
var ecode ErrorCode
var xcode xErrorCode
switch {
case errors.As(err, &xcode):
code = res_t(xcode)
case errors.As(err, &ecode):
code = res_t(ecode)
default:
code = res_t(def)
}
return err.Error(), code
}
// Package array provides the array table-valued SQL function.
//
// https://sqlite.org/carray.html
package array
import (
"fmt"
"reflect"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the array single-argument, table-valued SQL function.
// The argument must be bound to a Go slice or array of
// ints, floats, bools, strings or byte slices,
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
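//
// A minimal usage sketch (the bound slice is illustrative):
//
//	stmt, _, err := db.Prepare(`SELECT value FROM array(?)`)
//	if err != nil { /* handle error */ }
//	defer stmt.Close()
//	stmt.BindPointer(1, []int{1, 2, 3})
//	for stmt.Step() {
//		println(stmt.ColumnInt(0))
//	}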
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "array", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (array, error) {
err := db.DeclareVTab(`CREATE TABLE x(value, array HIDDEN)`)
return array{}, err
})
}
type array struct{}
func (array) BestIndex(idx *sqlite3.IndexInfo) error {
for i, cst := range idx.Constraint {
if cst.Column == 1 && cst.Op == sqlite3.INDEX_CONSTRAINT_EQ && cst.Usable {
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
Omit: true,
ArgvIndex: 1,
}
idx.EstimatedCost = 1
idx.EstimatedRows = 100
return nil
}
}
return sqlite3.CONSTRAINT
}
func (array) Open() (sqlite3.VTabCursor, error) {
return &cursor{}, nil
}
type cursor struct {
array reflect.Value
rowID int
}
func (c *cursor) EOF() bool {
return c.rowID >= c.array.Len()
}
func (c *cursor) Next() error {
c.rowID++
return nil
}
func (c *cursor) RowID() (int64, error) {
return int64(c.rowID), nil
}
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
if n != 0 {
return nil
}
v := c.array.Index(c.rowID)
k := v.Kind()
if k == reflect.Interface {
if v.IsNil() {
ctx.ResultNull()
return nil
}
v = v.Elem()
k = v.Kind()
}
switch {
case v.CanInt():
ctx.ResultInt64(v.Int())
case v.CanUint():
i64 := int64(v.Uint())
if i64 < 0 {
return fmt.Errorf("array: integer element overflow:%.0w %d", sqlite3.MISMATCH, v.Uint())
}
ctx.ResultInt64(i64)
case v.CanFloat():
ctx.ResultFloat(v.Float())
case k == reflect.Bool:
ctx.ResultBool(v.Bool())
case k == reflect.String:
ctx.ResultText(v.String())
case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
v.Type().Elem().Kind() == reflect.Uint8:
ctx.ResultBlob(v.Bytes())
default:
return fmt.Errorf("array: unsupported element:%.0w %v", sqlite3.MISMATCH, util.ReflectType(v))
}
return nil
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
array := reflect.ValueOf(arg[0].Pointer())
array, err := indexable(array)
if err != nil {
return err
}
c.array = array
c.rowID = 0
return nil
}
func indexable(v reflect.Value) (reflect.Value, error) {
switch v.Kind() {
case reflect.Slice:
return v, nil
case reflect.Array:
return v, nil
case reflect.Pointer:
if v := v.Elem(); v.Kind() == reflect.Array {
return v, nil
}
}
return v, fmt.Errorf("array: unsupported argument:%.0w %v", sqlite3.MISMATCH, util.ReflectType(v))
}
// Package blobio provides an SQL interface to incremental BLOB I/O.
package blobio
import (
"errors"
"io"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the SQL functions:
//
// readblob(schema, table, column, rowid, offset, n/writer)
//
// Reads n bytes of a blob, starting at offset.
//
// writeblob(schema, table, column, rowid, offset, data/reader)
//
// Writes data into a blob, at the given offset.
//
// openblob(schema, table, column, rowid, write, callback, args...)
//
// Opens blobs for reading or writing.
// The callback is invoked for each open blob,
// and must be bound to an [OpenCallback],
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
// The optional args will be passed to the callback,
// along with the [sqlite3.Blob] handle.
// The [sqlite3.Blob] handle is only valid during
// the execution of the callback. Callers cannot
// read or write to the handle after the callback
// exits.
//
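// For example, to stream a blob to an [io.Writer]
// (a sketch; the table, column, rowid and dst are illustrative):
//
//	stmt, _, _ := db.Prepare(
//		`SELECT openblob('main', 'images', 'data', ?, false, ?)`)
//	defer stmt.Close()
//	stmt.BindInt64(1, rowid)
//	stmt.BindPointer(2, blobio.OpenCallback(
//		func(b *sqlite3.Blob, _ ...sqlite3.Value) error {
//			_, err := b.WriteTo(dst)
//			return err
//		}))
//	if err := stmt.Exec(); err != nil {
//		// handle error
//	}
//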
// https://sqlite.org/c3ref/blob.html
func Register(db *sqlite3.Conn) error {
return errors.Join(
db.CreateFunction("readblob", 6, 0, readblob),
db.CreateFunction("writeblob", 6, 0, writeblob),
db.CreateFunction("openblob", -1, 0, openblob))
}
// OpenCallback is the type for the openblob callback.
type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error
func readblob(ctx sqlite3.Context, arg ...sqlite3.Value) {
_ = arg[5] // bounds check
blob, err := getAuxBlob(ctx, arg, false)
if err != nil {
ctx.ResultError(err)
return // notest
}
_, err = blob.Seek(arg[4].Int64(), io.SeekStart)
if err != nil {
ctx.ResultError(err)
return // notest
}
if p, ok := arg[5].Pointer().(io.Writer); ok {
var n int64
n, err = blob.WriteTo(p)
ctx.ResultInt64(n)
} else {
n := arg[5].Int64()
if n <= 0 {
return
}
buf := make([]byte, n)
_, err = io.ReadFull(blob, buf)
ctx.ResultBlob(buf)
}
if err != nil {
ctx.ResultError(err)
return // notest
}
setAuxBlob(ctx, blob, false)
}
func writeblob(ctx sqlite3.Context, arg ...sqlite3.Value) {
_ = arg[5] // bounds check
blob, err := getAuxBlob(ctx, arg, true)
if err != nil {
ctx.ResultError(err)
return // notest
}
_, err = blob.Seek(arg[4].Int64(), io.SeekStart)
if err != nil {
ctx.ResultError(err)
return // notest
}
if p, ok := arg[5].Pointer().(io.Reader); ok {
var n int64
n, err = blob.ReadFrom(p)
ctx.ResultInt64(n)
} else {
_, err = blob.Write(arg[5].RawBlob())
}
if err != nil {
ctx.ResultError(err)
return // notest
}
setAuxBlob(ctx, blob, false)
}
func openblob(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 6 {
ctx.ResultError(util.ErrorString("openblob: wrong number of arguments"))
return
}
blob, err := getAuxBlob(ctx, arg, arg[4].Bool())
if err != nil {
ctx.ResultError(err)
return // notest
}
fn := arg[5].Pointer().(OpenCallback)
err = fn(blob, arg[6:]...)
if err != nil {
ctx.ResultError(err)
return // notest
}
setAuxBlob(ctx, blob, true)
}
func getAuxBlob(ctx sqlite3.Context, arg []sqlite3.Value, write bool) (*sqlite3.Blob, error) {
row := arg[3].Int64()
if blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob); ok {
if err := blob.Reopen(row); errors.Is(err, sqlite3.MISUSE) {
// Blob was closed (db, table, column or write changed).
} else {
return blob, err
}
}
db := arg[0].Text()
table := arg[1].Text()
column := arg[2].Text()
return ctx.Conn().OpenBlob(db, table, column, row, write)
}
func setAuxBlob(ctx sqlite3.Context, blob *sqlite3.Blob, open bool) {
// This ensures the blob is closed if db, table, column or write change.
ctx.SetAuxData(0, blob) // db
ctx.SetAuxData(1, blob) // table
ctx.SetAuxData(2, blob) // column
if open {
ctx.SetAuxData(4, blob) // write
}
}
// Package bloom provides a Bloom filter virtual table.
//
// A Bloom filter is a space-efficient probabilistic data structure
// used to test whether an element is a member of a set.
//
// https://github.com/nalgeon/sqlean/issues/27#issuecomment-1002267134
package bloom
import (
"fmt"
"io"
"math"
"strconv"
"github.com/dchest/siphash"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the bloom_filter virtual table:
//
// CREATE VIRTUAL TABLE foo USING bloom_filter(nElements, falseProb, kHashes)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "bloom_filter", create, connect)
}
type bloom struct {
db *sqlite3.Conn
schema string
storage string
prob float64
bytes int64
hashes int
}
func create(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, err error) {
b := bloom{
db: db,
schema: schema,
storage: table + "_storage",
}
var nelem int64
if len(arg) > 0 {
nelem, err = strconv.ParseInt(arg[0], 10, 64)
if err != nil {
return nil, err
}
if nelem <= 0 {
return nil, util.ErrorString("bloom: number of elements in filter must be positive")
}
} else {
nelem = 100
}
if len(arg) > 1 {
b.prob, err = strconv.ParseFloat(arg[1], 64)
if err != nil {
return nil, err
}
if b.prob <= 0 || b.prob >= 1 {
return nil, util.ErrorString("bloom: probability must be in the range (0,1)")
}
} else {
b.prob = 0.01
}
if len(arg) > 2 {
b.hashes, err = strconv.Atoi(arg[2])
if err != nil {
return nil, err
}
if b.hashes <= 0 {
return nil, util.ErrorString("bloom: number of hash functions must be positive")
}
} else {
b.hashes = max(1, numHashes(b.prob))
}
b.bytes = numBytes(nelem, b.prob)
err = db.DeclareVTab(
`CREATE TABLE x(present, word HIDDEN NOT NULL PRIMARY KEY) WITHOUT ROWID`)
if err != nil {
return nil, err
}
err = db.Exec(fmt.Sprintf(
`CREATE TABLE %s.%s (data BLOB, p REAL, n INTEGER, m INTEGER, k INTEGER)`,
sqlite3.QuoteIdentifier(b.schema), sqlite3.QuoteIdentifier(b.storage)))
if err != nil {
return nil, err
}
id := db.LastInsertRowID()
defer db.SetLastInsertRowID(id)
err = db.Exec(fmt.Sprintf(
`INSERT INTO %s.%s (rowid, data, p, n, m, k)
VALUES (1, zeroblob(%d), %f, %d, %d, %d)`,
sqlite3.QuoteIdentifier(b.schema), sqlite3.QuoteIdentifier(b.storage),
b.bytes, b.prob, nelem, 8*b.bytes, b.hashes))
if err != nil {
b.Destroy()
return nil, err
}
return &b, nil
}
func connect(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom, err error) {
b := bloom{
db: db,
schema: schema,
storage: table + "_storage",
}
err = db.DeclareVTab(
`CREATE TABLE x(present, word HIDDEN NOT NULL PRIMARY KEY) WITHOUT ROWID`)
if err != nil {
return nil, err
}
load, _, err := db.Prepare(fmt.Sprintf(
`SELECT m/8, p, k FROM %s.%s WHERE rowid = 1`,
sqlite3.QuoteIdentifier(b.schema), sqlite3.QuoteIdentifier(b.storage)))
if err != nil {
return nil, err
}
defer load.Close()
if !load.Step() {
if err := load.Err(); err != nil {
return nil, err
}
return nil, sqlite3.CORRUPT_VTAB
}
b.bytes = load.ColumnInt64(0)
b.prob = load.ColumnFloat(1)
b.hashes = load.ColumnInt(2)
return &b, nil
}
func (b *bloom) Destroy() error {
return b.db.Exec(fmt.Sprintf(`DROP TABLE %s.%s`,
sqlite3.QuoteIdentifier(b.schema),
sqlite3.QuoteIdentifier(b.storage)))
}
func (b *bloom) Rename(new string) error {
new += "_storage"
err := b.db.Exec(fmt.Sprintf(`ALTER TABLE %s.%s RENAME TO %s`,
sqlite3.QuoteIdentifier(b.schema),
sqlite3.QuoteIdentifier(b.storage),
sqlite3.QuoteIdentifier(new),
))
if err == nil {
b.storage = new
}
return err
}
func (t *bloom) ShadowTables() {
// notest // not meant to be called
}
func (t *bloom) Integrity(schema, table string, flags int) error {
load, _, err := t.db.Prepare(fmt.Sprintf(
`SELECT typeof(data), length(data), p, n, m, k FROM %s.%s WHERE rowid = 1`,
sqlite3.QuoteIdentifier(t.schema), sqlite3.QuoteIdentifier(t.storage)))
if err != nil {
return fmt.Errorf("bloom: %v", err) // can't wrap!
}
defer load.Close()
err = util.ErrorString("bloom: invalid parameters")
if !load.Step() {
return err
}
if t := load.ColumnText(0); t != "blob" {
return err
}
if m := load.ColumnInt64(4); m <= 0 || m%8 != 0 {
return err
} else if load.ColumnInt64(1) != m/8 {
return err
}
if p := load.ColumnFloat(2); p <= 0 || p >= 1 {
return err
}
if n := load.ColumnInt64(3); n <= 0 {
return err
}
if k := load.ColumnInt(5); k <= 0 {
return err
}
return nil
}
func (b *bloom) BestIndex(idx *sqlite3.IndexInfo) error {
for i, cst := range idx.Constraint {
if cst.Usable && cst.Column == 1 &&
cst.Op == sqlite3.INDEX_CONSTRAINT_EQ {
idx.ConstraintUsage[i].ArgvIndex = 1
idx.OrderByConsumed = true
idx.EstimatedRows = 1
idx.EstimatedCost = float64(b.hashes)
idx.IdxFlags = sqlite3.INDEX_SCAN_UNIQUE
return nil
}
}
return sqlite3.CONSTRAINT
}
func (b *bloom) Update(arg ...sqlite3.Value) (rowid int64, err error) {
if arg[0].Type() != sqlite3.NULL {
if len(arg) == 1 {
return 0, util.ErrorString("bloom: elements cannot be deleted")
}
return 0, util.ErrorString("bloom: elements cannot be updated")
}
if arg[2].NoChange() {
return 0, nil
}
blob := arg[2].RawBlob()
f, err := b.db.OpenBlob(b.schema, b.storage, "data", 1, true)
if err != nil {
return 0, err
}
defer f.Close()
for n := range b.hashes {
hash := calcHash(n, blob)
hash %= uint64(b.bytes * 8)
bitpos := byte(hash % 8)
bytepos := int64(hash / 8)
var buf [1]byte
_, err = f.Seek(bytepos, io.SeekStart)
if err != nil {
return 0, err
}
_, err = f.Read(buf[:])
if err != nil {
return 0, err
}
buf[0] |= 1 << bitpos
_, err = f.Seek(bytepos, io.SeekStart)
if err != nil {
return 0, err
}
_, err = f.Write(buf[:])
if err != nil {
return 0, err
}
}
return 0, nil
}
func (b *bloom) Open() (sqlite3.VTabCursor, error) {
return &cursor{bloom: b}, nil
}
type cursor struct {
*bloom
arg sqlite3.Value
eof bool
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.eof = false
c.arg = arg[0]
blob := arg[0].RawBlob()
f, err := c.db.OpenBlob(c.schema, c.storage, "data", 1, false)
if err != nil {
return err
}
defer f.Close()
for n := 0; n < c.hashes && !c.eof; n++ {
hash := calcHash(n, blob)
hash %= uint64(c.bytes * 8)
bitpos := byte(hash % 8)
bytepos := int64(hash / 8)
var buf [1]byte
_, err = f.Seek(bytepos, io.SeekStart)
if err != nil {
return err
}
_, err = f.Read(buf[:])
if err != nil {
return err
}
c.eof = buf[0]&(1<<bitpos) == 0
}
return nil
}
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
if ctx.VTabNoChange() {
return nil
}
switch n {
case 0:
ctx.ResultBool(true)
case 1:
ctx.ResultValue(c.arg)
}
return nil
}
func (c *cursor) Next() error {
c.eof = true
return nil
}
func (c *cursor) EOF() bool {
return c.eof
}
func (c *cursor) RowID() (int64, error) {
// notest // WITHOUT ROWID
return 0, nil
}
func calcHash(k int, b []byte) uint64 {
return siphash.Hash(^uint64(k), uint64(k), b)
}
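// numHashes returns the optimal number of hash functions,
// k = round(-log2(p)), for a target false-positive probability p.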
func numHashes(p float64) int {
k := math.Round(-math.Log2(p))
return max(1, int(k))
}
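// numBytes returns the Bloom filter size in bytes for n elements and
// false-positive probability p, using m = -n*ln(p)/(ln 2)^2 bits.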
func numBytes(n int64, p float64) int64 {
m := math.Ceil(float64(n) * math.Log(p) / -(math.Ln2 * math.Ln2))
return (int64(m) + 7) / 8
}
// Package closure provides a transitive closure virtual table.
//
// The transitive_closure virtual table finds the transitive closure of
// a parent/child relationship in a real table.
//
// https://sqlite.org/src/doc/tip/ext/misc/closure.c
package closure
import (
"fmt"
"math"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/sql3util"
)
const (
_COL_ID = 0
_COL_DEPTH = 1
_COL_ROOT = 2
_COL_TABLENAME = 3
_COL_IDCOLUMN = 4
_COL_PARENTCOLUMN = 5
)
// Register registers the transitive_closure virtual table:
//
// CREATE VIRTUAL TABLE temp.closure USING transitive_closure;
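//
// A query sketch (employees, id and manager_id are illustrative names):
//
//	SELECT id, depth FROM closure
//	 WHERE root = 1 AND tablename = 'employees'
//	   AND idcolumn = 'id' AND parentcolumn = 'manager_id';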
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "transitive_closure", nil,
func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*closure, error) {
var (
table string
column string
parent string
done = util.Set[string]{}
)
for _, arg := range arg {
key, val := sql3util.NamedArg(arg)
if done.Contains(key) {
return nil, fmt.Errorf("transitive_closure: more than one %q parameter", key)
}
switch key {
case "tablename":
table = sql3util.Unquote(val)
case "idcolumn":
column = sql3util.Unquote(val)
case "parentcolumn":
parent = sql3util.Unquote(val)
default:
return nil, fmt.Errorf("transitive_closure: unknown %q parameter", key)
}
done.Add(key)
}
err := db.DeclareVTab(`CREATE TABLE x(id,depth,root HIDDEN,tablename HIDDEN,idcolumn HIDDEN,parentcolumn HIDDEN)`)
if err != nil {
return nil, err
}
return &closure{
db: db,
table: table,
column: column,
parent: parent,
}, nil
})
}
type closure struct {
db *sqlite3.Conn
table string
column string
parent string
}
func (c *closure) Destroy() error { return nil }
func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
plan := 0
posi := 1
cost := 1e7
for i, cst := range idx.Constraint {
switch {
case !cst.Usable:
continue
case plan&1 == 0 && cst.Column == _COL_ROOT:
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_EQ:
plan |= 1
cost /= 100
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: 1,
Omit: true,
}
}
case plan&0xf0 == 0 && cst.Column == _COL_DEPTH:
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_LT, sqlite3.INDEX_CONSTRAINT_LE, sqlite3.INDEX_CONSTRAINT_EQ:
plan |= posi << 4
cost /= 5
posi += 1
idx.ConstraintUsage[i].ArgvIndex = posi
if cst.Op == sqlite3.INDEX_CONSTRAINT_LT {
plan |= 2
}
}
case plan&0xf00 == 0 && cst.Column == _COL_TABLENAME:
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_EQ:
plan |= posi << 8
cost /= 5
posi += 1
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: posi,
Omit: true,
}
}
case plan&0xf000 == 0 && cst.Column == _COL_IDCOLUMN:
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_EQ:
plan |= posi << 12
posi += 1
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: posi,
Omit: true,
}
}
case plan&0xf0000 == 0 && cst.Column == _COL_PARENTCOLUMN:
switch cst.Op {
case sqlite3.INDEX_CONSTRAINT_EQ:
plan |= posi << 16
posi += 1
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: posi,
Omit: true,
}
}
}
}
if plan&1 == 0 ||
c.table == "" && plan&0xf00 == 0 ||
c.column == "" && plan&0xf000 == 0 ||
c.parent == "" && plan&0xf0000 == 0 {
return sqlite3.CONSTRAINT
}
idx.EstimatedCost = cost
idx.IdxNum = plan
return nil
}
func (c *closure) Open() (sqlite3.VTabCursor, error) {
return &cursor{closure: c}, nil
}
type cursor struct {
*closure
nodes []node
}
type node struct {
id int64
depth int
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
root := arg[0].Int64()
maxDepth := math.MaxInt
if idxNum&0xf0 != 0 {
maxDepth = arg[(idxNum>>4)&0xf].Int()
if idxNum&2 != 0 {
maxDepth -= 1
}
}
table := c.table
if idxNum&0xf00 != 0 {
table = arg[(idxNum>>8)&0xf].Text()
}
column := c.column
if idxNum&0xf000 != 0 {
column = arg[(idxNum>>12)&0xf].Text()
}
parent := c.parent
if idxNum&0xf0000 != 0 {
parent = arg[(idxNum>>16)&0xf].Text()
}
sql := fmt.Sprintf(
`SELECT %[1]s.%[2]s FROM %[1]s WHERE %[1]s.%[3]s=?`,
sqlite3.QuoteIdentifier(table),
sqlite3.QuoteIdentifier(column),
sqlite3.QuoteIdentifier(parent),
)
stmt, _, err := c.db.Prepare(sql)
if err != nil {
return err
}
defer stmt.Close()
c.nodes = []node{{root, 0}}
set := util.Set[int64]{}
set.Add(root)
for i := range c.nodes {
curr := c.nodes[i]
if curr.depth >= maxDepth {
continue
}
if err := stmt.BindInt64(1, curr.id); err != nil {
return err
}
for stmt.Step() {
if stmt.ColumnType(0) == sqlite3.INTEGER {
next := stmt.ColumnInt64(0)
if !set.Contains(next) {
set.Add(next)
c.nodes = append(c.nodes, node{next, curr.depth + 1})
}
}
}
if err := stmt.Reset(); err != nil {
return err
}
}
return nil
}
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
switch n {
case _COL_ID:
ctx.ResultInt64(c.nodes[0].id)
case _COL_DEPTH:
ctx.ResultInt(c.nodes[0].depth)
case _COL_TABLENAME:
ctx.ResultText(c.table)
case _COL_IDCOLUMN:
ctx.ResultText(c.column)
case _COL_PARENTCOLUMN:
ctx.ResultText(c.parent)
}
return nil
}
func (c *cursor) Next() error {
c.nodes = c.nodes[1:]
return nil
}
func (c *cursor) EOF() bool {
return len(c.nodes) == 0
}
func (c *cursor) RowID() (int64, error) {
return c.nodes[0].id, nil
}
package csv
import (
"fmt"
"strconv"
"github.com/ncruces/go-sqlite3/util/sql3util"
)
func uintArg(key, val string) (int, error) {
i, err := strconv.ParseUint(val, 10, 15)
if err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return int(i), nil
}
func boolArg(key, val string) (bool, error) {
if val == "" {
return true, nil
}
b, ok := sql3util.ParseBool(val)
if ok {
return b, nil
}
return false, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
func runeArg(key, val string) (rune, error) {
r, _, tail, err := strconv.UnquoteChar(sql3util.Unquote(val), 0)
if tail != "" || err != nil {
return 0, fmt.Errorf("csv: invalid %q parameter: %s", key, val)
}
return r, nil
}
// Package csv provides a CSV virtual table.
//
// The CSV virtual table reads RFC 4180 formatted comma-separated values,
// and returns that content as if it were rows and columns of an SQL table.
//
// https://sqlite.org/csv.html
package csv
import (
"bufio"
"encoding/csv"
"fmt"
"io"
"io/fs"
"strconv"
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/osutil"
"github.com/ncruces/go-sqlite3/util/sql3util"
)
// Register registers the CSV virtual table.
// If a filename is specified, [os.Open] is used to open the file.
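//
// A usage sketch (the file name is illustrative):
//
//	CREATE VIRTUAL TABLE temp.users USING csv(filename='users.csv', header=true);
//	SELECT * FROM temp.users;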
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, osutil.FS{})
}
// RegisterFS registers the CSV virtual table.
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
var (
filename string
data string
schema string
header bool
columns int = -1
comma rune = ','
comment rune
done = util.Set[string]{}
)
for _, arg := range arg {
key, val := sql3util.NamedArg(arg)
if done.Contains(key) {
return nil, fmt.Errorf("csv: more than one %q parameter", key)
}
switch key {
case "filename":
filename = sql3util.Unquote(val)
case "data":
data = sql3util.Unquote(val)
case "schema":
schema = sql3util.Unquote(val)
case "header":
header, err = boolArg(key, val)
case "columns":
columns, err = uintArg(key, val)
case "comma":
comma, err = runeArg(key, val)
case "comment":
comment, err = runeArg(key, val)
default:
return nil, fmt.Errorf("csv: unknown %q parameter", key)
}
if err != nil {
return nil, err
}
done.Add(key)
}
if (filename == "") == (data == "") {
return nil, util.ErrorString(`csv: must specify either "filename" or "data" but not both`)
}
t := &table{
fsys: fsys,
name: filename,
data: data,
comma: comma,
comment: comment,
header: header,
}
if schema == "" {
var row []string
if header || columns < 0 {
				csv, c, err := t.newReader()
				if c != nil {
					defer c.Close()
				}
				if err != nil {
					return nil, err
				}
row, err = csv.Read()
if err != nil {
return nil, err
}
}
schema = getSchema(header, columns, row)
} else {
t.typs, err = getColumnAffinities(schema)
if err != nil {
return nil, err
}
}
err = db.DeclareVTab(schema)
if err == nil {
err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
}
if err != nil {
return nil, err
}
return t, nil
}
return sqlite3.CreateModule(db, "csv", declare, declare)
}
type table struct {
fsys fs.FS
name string
data string
typs []affinity
comma rune
comment rune
header bool
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
idx.EstimatedCost = 1e6
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
return &cursor{table: t}, nil
}
func (t *table) Rename(new string) error {
return nil
}
func (t *table) Integrity(schema, table string, flags int) error {
if flags&1 != 0 {
return nil
}
csv, c, err := t.newReader()
if err != nil {
return err
}
defer c.Close()
_, err = csv.ReadAll()
return err
}
func (t *table) newReader() (*csv.Reader, io.Closer, error) {
var r io.Reader
var c io.Closer
if t.name != "" {
f, err := t.fsys.Open(t.name)
if err != nil {
return nil, f, err
}
buf := bufio.NewReader(f)
bom, err := buf.Peek(3)
if err != nil {
return nil, f, err
}
if string(bom) == "\xEF\xBB\xBF" {
buf.Discard(3)
}
r = buf
c = f
} else {
r = strings.NewReader(t.data)
c = io.NopCloser(r)
}
csv := csv.NewReader(r)
csv.ReuseRecord = true
csv.Comma = t.comma
csv.Comment = t.comment
return csv, c, nil
}
type cursor struct {
table *table
closer io.Closer
csv *csv.Reader
row []string
rowID int64
}
func (c *cursor) Close() (err error) {
if c.closer != nil {
err = c.closer.Close()
c.closer = nil
}
return err
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
err := c.Close()
if err != nil {
return err
}
c.csv, c.closer, err = c.table.newReader()
if err != nil {
return err
}
if c.table.header {
err = c.Next() // skip header
if err != nil {
return err
}
}
c.rowID = 0
return c.Next()
}
func (c *cursor) Next() (err error) {
c.rowID++
c.row, err = c.csv.Read()
if err != io.EOF {
return err
}
return nil
}
func (c *cursor) EOF() bool {
return c.row == nil
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx sqlite3.Context, col int) error {
if col < len(c.row) {
typ := text
if col < len(c.table.typs) {
typ = c.table.typs[col]
}
txt := c.row[col]
if txt == "" && typ != text {
return nil
}
switch typ {
case numeric, integer:
if strings.TrimLeft(txt, "+-0123456789") == "" {
if i, err := strconv.ParseInt(txt, 10, 64); err == nil {
ctx.ResultInt64(i)
return nil
}
}
fallthrough
case real:
if strings.TrimLeft(txt, "+-.0123456789Ee") == "" {
if f, err := strconv.ParseFloat(txt, 64); err == nil {
ctx.ResultFloat(f)
return nil
}
}
fallthrough
default:
}
ctx.ResultText(txt)
}
return nil
}
package csv
import (
"strconv"
"strings"
"github.com/ncruces/go-sqlite3"
)
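// getSchema builds a CREATE TABLE statement for the virtual table,
// naming columns after the header row, or c1, c2, ... otherwise.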
func getSchema(header bool, columns int, row []string) string {
var sep string
var str strings.Builder
str.WriteString("CREATE TABLE x(")
if 0 <= columns && columns < len(row) {
row = row[:columns]
}
for i, f := range row {
str.WriteString(sep)
if header && f != "" {
str.WriteString(sqlite3.QuoteIdentifier(f))
} else {
str.WriteString("c")
str.WriteString(strconv.Itoa(i + 1))
}
str.WriteString(" TEXT")
sep = ","
}
for i := len(row); i < columns; i++ {
str.WriteString(sep)
str.WriteString("c")
str.WriteString(strconv.Itoa(i + 1))
str.WriteString(" TEXT")
sep = ","
}
str.WriteByte(')')
return str.String()
}
package csv
import (
"strings"
"github.com/ncruces/go-sqlite3/util/sql3util"
)
type affinity byte
const (
blob affinity = 0
text affinity = 1
numeric affinity = 2
integer affinity = 3
real affinity = 4
)
func getColumnAffinities(schema string) ([]affinity, error) {
tab, err := sql3util.ParseTable(schema)
if err != nil {
return nil, err
}
columns := tab.Columns
types := make([]affinity, len(columns))
for i, col := range columns {
types[i] = getAffinity(col.Type)
}
return types, nil
}
func getAffinity(declType string) affinity {
// https://sqlite.org/datatype3.html#determination_of_column_affinity
if declType == "" {
return blob
}
name := strings.ToUpper(declType)
if strings.Contains(name, "INT") {
return integer
}
if strings.Contains(name, "CHAR") || strings.Contains(name, "CLOB") || strings.Contains(name, "TEXT") {
return text
}
if strings.Contains(name, "BLOB") {
return blob
}
if strings.Contains(name, "REAL") || strings.Contains(name, "FLOA") || strings.Contains(name, "DOUB") {
return real
}
return numeric
}
// Package fileio provides SQL functions to read, write and list files.
//
// https://sqlite.org/src/doc/tip/ext/misc/fileio.c
package fileio
import (
"errors"
"fmt"
"io/fs"
"os"
"github.com/ncruces/go-sqlite3"
)
// Register registers SQL functions readfile, writefile, lsmode,
// and the table-valued function fsdir.
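//
// A usage sketch:
//
//	SELECT writefile('hello.txt', 'Hello world!');
//	SELECT readfile('hello.txt');
//	SELECT name, lsmode(mode) FROM fsdir('.');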
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, nil)
}
// RegisterFS registers SQL functions readfile, lsmode,
// and the table-valued function fsdir;
// fsys is used to read files and list directories.
// writefile is only registered when fsys is nil.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
var err error
if fsys == nil {
err = db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
}
return errors.Join(err,
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)),
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode),
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
if err == nil {
err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
}
return fsdir{fsys}, err
}))
}
func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultText(fs.FileMode(arg[0].Int()).String())
}
func readfile(fsys fs.FS) sqlite3.ScalarFunction {
return func(ctx sqlite3.Context, arg ...sqlite3.Value) {
var err error
var data []byte
if fsys != nil {
data, err = fs.ReadFile(fsys, arg[0].Text())
} else {
data, err = os.ReadFile(arg[0].Text())
}
switch {
case err == nil:
ctx.ResultBlob(data)
case !errors.Is(err, fs.ErrNotExist):
ctx.ResultError(fmt.Errorf("readfile: %w", err)) // notest
}
}
}
package fileio
import (
"io/fs"
"iter"
"os"
"path"
"path/filepath"
"strings"
"github.com/ncruces/go-sqlite3"
)
const (
_COL_NAME = 0
_COL_MODE = 1
_COL_TIME = 2
_COL_DATA = 3
_COL_ROOT = 4
_COL_BASE = 5
)
type fsdir struct{ fsys fs.FS }
func (d fsdir) BestIndex(idx *sqlite3.IndexInfo) error {
var root, base bool
for i, cst := range idx.Constraint {
switch cst.Column {
case _COL_ROOT:
if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
return sqlite3.CONSTRAINT
}
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
Omit: true,
ArgvIndex: 1,
}
root = true
case _COL_BASE:
if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
return sqlite3.CONSTRAINT
}
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
Omit: true,
ArgvIndex: 2,
}
base = true
}
}
if !root {
return sqlite3.CONSTRAINT
}
if base {
idx.EstimatedCost = 10
} else {
idx.EstimatedCost = 100
}
return nil
}
func (d fsdir) Open() (sqlite3.VTabCursor, error) {
return &cursor{fsdir: d}, nil
}
type cursor struct {
fsdir
base string
next func() (entry, bool)
stop func()
curr entry
eof bool
rowID int64
}
type entry struct {
fs.DirEntry
err error
path string
}
func (c *cursor) Close() error {
if c.stop != nil {
c.stop()
}
return nil
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
if err := c.Close(); err != nil {
return err
}
root := arg[0].Text()
if len(arg) > 1 {
base := arg[1].Text()
if c.fsys != nil {
root = path.Join(base, root)
base = path.Clean(base) + "/"
} else {
root = filepath.Join(base, root)
base = filepath.Clean(base) + string(filepath.Separator)
}
c.base = base
}
c.next, c.stop = iter.Pull(func(yield func(entry) bool) {
walkDir := func(path string, d fs.DirEntry, err error) error {
if yield(entry{d, err, path}) {
return nil
}
return fs.SkipAll
}
if c.fsys != nil {
fs.WalkDir(c.fsys, root, walkDir)
} else {
filepath.WalkDir(root, walkDir)
}
})
c.eof = false
c.rowID = 0
return c.Next()
}
func (c *cursor) Next() error {
curr, ok := c.next()
c.curr = curr
c.eof = !ok
c.rowID++
return c.curr.err
}
func (c *cursor) EOF() bool {
return c.eof
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
switch n {
case _COL_NAME:
name := strings.TrimPrefix(c.curr.path, c.base)
ctx.ResultText(name)
case _COL_MODE:
i, err := c.curr.Info()
if err != nil {
return err
}
ctx.ResultInt64(int64(i.Mode()))
case _COL_TIME:
i, err := c.curr.Info()
if err != nil {
return err
}
ctx.ResultTime(i.ModTime(), sqlite3.TimeFormatUnixFrac)
case _COL_DATA:
switch typ := c.curr.Type(); {
case typ.IsRegular():
var data []byte
var err error
if c.fsys != nil {
data, err = fs.ReadFile(c.fsys, c.curr.path)
} else {
data, err = os.ReadFile(c.curr.path)
}
if err != nil {
return err
}
ctx.ResultBlob(data)
case typ&fs.ModeSymlink != 0 && c.fsys == nil:
t, err := os.Readlink(c.curr.path)
if err != nil {
return err
}
ctx.ResultText(t)
}
}
return nil
}
package fileio
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/fsutil"
)
func writefile(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 2 || len(arg) > 4 {
ctx.ResultError(util.ErrorString("writefile: wrong number of arguments"))
return
}
file := arg[0].Text()
var mode fs.FileMode
if len(arg) > 2 {
mode = fsutil.FileModeFromValue(arg[2])
}
n, err := createFileAndDir(file, mode, arg[1])
if err != nil {
if len(arg) > 2 {
ctx.ResultError(fmt.Errorf("writefile: %w", err)) // notest
}
return
}
if mode&fs.ModeSymlink == 0 {
if len(arg) > 2 {
err := os.Chmod(file, mode.Perm())
if err != nil {
ctx.ResultError(fmt.Errorf("writefile: %w", err))
return // notest
}
}
if len(arg) > 3 {
mtime := arg[3].Time(sqlite3.TimeFormatUnixFrac)
err := os.Chtimes(file, time.Time{}, mtime)
if err != nil {
ctx.ResultError(fmt.Errorf("writefile: %w", err))
return // notest
}
}
}
if mode.IsRegular() {
ctx.ResultInt(n)
}
}
func createFileAndDir(path string, mode fs.FileMode, data sqlite3.Value) (int, error) {
n, err := createFile(path, mode, data)
if errors.Is(err, fs.ErrNotExist) {
if err := os.MkdirAll(filepath.Dir(path), 0777); err == nil {
return createFile(path, mode, data)
}
}
return n, err
}
func createFile(path string, mode fs.FileMode, data sqlite3.Value) (int, error) {
if mode.IsRegular() {
blob := data.RawBlob()
return len(blob), os.WriteFile(path, blob, fixPerm(mode, 0666))
}
if mode.IsDir() {
err := os.Mkdir(path, fixPerm(mode, 0777))
if errors.Is(err, fs.ErrExist) {
s, err := os.Lstat(path)
if err == nil && s.IsDir() {
return 0, nil
}
}
return 0, err
}
if mode&fs.ModeSymlink != 0 {
return 0, os.Symlink(data.Text(), path)
}
return 0, fmt.Errorf("invalid mode: %v", mode)
}
func fixPerm(mode fs.FileMode, def fs.FileMode) fs.FileMode {
if mode.Perm() == 0 {
return def
}
return mode.Perm()
}
package hash
import (
"crypto"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func blake2sFunc(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.BLAKE2s_256)
}
func blake2bFunc(ctx sqlite3.Context, arg ...sqlite3.Value) {
size := 512
if len(arg) > 1 {
size = arg[1].Int()
}
switch size {
case 256:
hashFunc(ctx, arg[0], crypto.BLAKE2b_256)
case 384:
hashFunc(ctx, arg[0], crypto.BLAKE2b_384)
case 512:
hashFunc(ctx, arg[0], crypto.BLAKE2b_512)
default:
ctx.ResultError(util.ErrorString("blake2b: size must be 256, 384, 512"))
}
}
// Package hash provides cryptographic hash functions.
//
// Provided functions:
// - md4(data)
// - md5(data)
// - sha1(data)
// - sha3(data, size) (default size 256)
// - sha224(data)
// - sha256(data, size) (default size 256)
// - sha384(data)
// - sha512(data, size) (default size 512)
// - blake2s(data)
// - blake2b(data, size) (default size 512)
// - ripemd160(data)
//
// Each SQL function will only be registered if the corresponding
// [crypto.Hash] function is available.
// To ensure a specific hash function is available,
// import the implementing package.
package hash
import (
"crypto"
"errors"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers cryptographic hash functions for a database connection.
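//
// A minimal usage sketch, assuming db is a [sqlite3.Conn] opened elsewhere
// and that crypto/sha256 has been imported, making [crypto.SHA256] available:
//
//	if err := Register(db); err != nil {
//		// handle the registration error
//	}
//	stmt, _, err := db.Prepare(`SELECT hex(sha256('hello, world'))`)
//	if err == nil {
//		if stmt.Step() {
//			// the first column holds the hex-encoded SHA-256 digest
//		}
//		stmt.Close()
//	}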
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
var errs util.ErrorJoiner
if crypto.MD4.Available() {
errs.Join(
db.CreateFunction("md4", 1, flags, md4Func))
}
if crypto.MD5.Available() {
errs.Join(
db.CreateFunction("md5", 1, flags, md5Func))
}
if crypto.SHA1.Available() {
errs.Join(
db.CreateFunction("sha1", 1, flags, sha1Func))
}
if crypto.SHA3_512.Available() {
errs.Join(
db.CreateFunction("sha3", 1, flags, sha3Func),
db.CreateFunction("sha3", 2, flags, sha3Func))
}
if crypto.SHA256.Available() {
errs.Join(
db.CreateFunction("sha224", 1, flags, sha224Func),
db.CreateFunction("sha256", 1, flags, sha256Func),
db.CreateFunction("sha256", 2, flags, sha256Func))
}
if crypto.SHA512.Available() {
errs.Join(
db.CreateFunction("sha384", 1, flags, sha384Func),
db.CreateFunction("sha512", 1, flags, sha512Func),
db.CreateFunction("sha512", 2, flags, sha512Func))
}
if crypto.BLAKE2s_256.Available() {
errs.Join(
db.CreateFunction("blake2s", 1, flags, blake2sFunc))
}
if crypto.BLAKE2b_512.Available() {
errs.Join(
db.CreateFunction("blake2b", 1, flags, blake2bFunc),
db.CreateFunction("blake2b", 2, flags, blake2bFunc))
}
if crypto.RIPEMD160.Available() {
errs.Join(
db.CreateFunction("ripemd160", 1, flags, ripemd160Func))
}
return errors.Join(errs...)
}
func md4Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.MD4)
}
func md5Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.MD5)
}
func sha1Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.SHA1)
}
func ripemd160Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.RIPEMD160)
}
func hashFunc(ctx sqlite3.Context, arg sqlite3.Value, fn crypto.Hash) {
var data []byte
switch arg.Type() {
case sqlite3.NULL:
return
case sqlite3.BLOB:
data = arg.RawBlob()
default:
data = arg.RawText()
}
h := fn.New()
h.Write(data)
ctx.ResultBlob(h.Sum(nil))
}
package hash
import (
"crypto"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func sha224Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.SHA224)
}
func sha384Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
hashFunc(ctx, arg[0], crypto.SHA384)
}
func sha256Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
size := 256
if len(arg) > 1 {
size = arg[1].Int()
}
switch size {
case 224:
hashFunc(ctx, arg[0], crypto.SHA224)
case 256:
hashFunc(ctx, arg[0], crypto.SHA256)
default:
ctx.ResultError(util.ErrorString("sha256: size must be 224, 256"))
}
}
func sha512Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
size := 512
if len(arg) > 1 {
size = arg[1].Int()
}
switch size {
case 224:
hashFunc(ctx, arg[0], crypto.SHA512_224)
case 256:
hashFunc(ctx, arg[0], crypto.SHA512_256)
case 384:
hashFunc(ctx, arg[0], crypto.SHA384)
case 512:
hashFunc(ctx, arg[0], crypto.SHA512)
default:
ctx.ResultError(util.ErrorString("sha512: size must be 224, 256, 384, 512"))
}
}
package hash
import (
"crypto"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func sha3Func(ctx sqlite3.Context, arg ...sqlite3.Value) {
size := 256
if len(arg) > 1 {
size = arg[1].Int()
}
switch size {
case 224:
hashFunc(ctx, arg[0], crypto.SHA3_224)
case 256:
hashFunc(ctx, arg[0], crypto.SHA3_256)
case 384:
hashFunc(ctx, arg[0], crypto.SHA3_384)
case 512:
hashFunc(ctx, arg[0], crypto.SHA3_512)
default:
ctx.ResultError(util.ErrorString("sha3: size must be 224, 256, 384, 512"))
}
}
// Package ipaddr provides functions to manipulate IPs and CIDRs.
//
// It provides the following functions:
// - ipcontains(prefix, ip)
// - ipoverlaps(prefix1, prefix2)
// - ipfamily(ip/prefix)
// - iphost(ip/prefix)
// - ipmasklen(prefix)
// - ipnetwork(prefix)
package ipaddr
import (
"errors"
"net/netip"
"github.com/ncruces/go-sqlite3"
)
// Register registers IP/CIDR functions for a database connection.
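//
// A minimal usage sketch, assuming db is a [sqlite3.Conn] opened elsewhere:
//
//	if err := Register(db); err != nil {
//		// handle the registration error
//	}
//	stmt, _, err := db.Prepare(
//		`SELECT ipcontains('192.168.16.0/24', '192.168.16.3'), ipfamily('::1')`)
//	if err == nil {
//		if stmt.Step() {
//			// columns are 1 (the prefix contains the address) and 6 (IPv6)
//		}
//		stmt.Close()
//	}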
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("ipcontains", 2, flags, contains),
db.CreateFunction("ipoverlaps", 2, flags, overlaps),
db.CreateFunction("ipfamily", 1, flags, family),
db.CreateFunction("iphost", 1, flags, host),
db.CreateFunction("ipmasklen", 1, flags, masklen),
db.CreateFunction("ipnetwork", 1, flags, network))
}
func contains(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
addr, err := netip.ParseAddr(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultBool(prefix.Contains(addr))
}
func overlaps(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix1, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
prefix2, err := netip.ParsePrefix(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultBool(prefix1.Overlaps(prefix2))
}
func family(ctx sqlite3.Context, arg ...sqlite3.Value) {
addr, err := addr(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
switch {
case addr.Is4():
ctx.ResultInt(4)
case addr.Is6():
ctx.ResultInt(6)
}
}
func host(ctx sqlite3.Context, arg ...sqlite3.Value) {
addr, err := addr(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
buf, _ := addr.MarshalText()
ctx.ResultRawText(buf)
}
func masklen(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
ctx.ResultInt(prefix.Bits())
}
func network(ctx sqlite3.Context, arg ...sqlite3.Value) {
prefix, err := netip.ParsePrefix(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
buf, _ := prefix.Masked().MarshalText()
ctx.ResultRawText(buf)
}
func addr(text string) (netip.Addr, error) {
addr, err := netip.ParseAddr(text)
if err != nil {
if prefix, err := netip.ParsePrefix(text); err == nil {
return prefix.Addr(), nil
}
if addrpt, err := netip.ParseAddrPort(text); err == nil {
return addrpt.Addr(), nil
}
}
return addr, err
}
// Package lines provides a virtual table to read data line-by-line.
//
// It is particularly useful for line-oriented datasets,
// like [ndjson] or [JSON Lines],
// when paired with SQLite's JSON support.
//
// https://github.com/asg017/sqlite-lines
//
// [ndjson]: https://ndjson.org/
// [JSON Lines]: https://jsonlines.org/
package lines
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/util/osutil"
)
// Register registers the lines and lines_read table-valued functions.
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, [os.Open] is used to open the file.
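//
// A minimal usage sketch, assuming db is a [sqlite3.Conn] opened elsewhere:
//
//	if err := Register(db); err != nil {
//		// handle the registration error
//	}
//	stmt, _, err := db.Prepare(
//		`SELECT line FROM lines('alpha' || char(10) || 'beta')`)
//	if err == nil {
//		for stmt.Step() {
//			// yields "alpha", then "beta"
//		}
//		stmt.Close()
//	}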
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, osutil.FS{})
}
// RegisterFS registers the lines and lines_read table-valued functions.
// The lines function reads from a database blob or text.
// The lines_read function reads from a file or an [io.Reader].
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
return errors.Join(
sqlite3.CreateModule(db, "lines", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN, delim HIDDEN)`)
if err == nil {
err = db.VTabConfig(sqlite3.VTAB_INNOCUOUS)
}
return lines{}, err
}),
sqlite3.CreateModule(db, "lines_read", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) {
err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN, delim HIDDEN)`)
if err == nil {
err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
}
return lines{fsys}, err
}))
}
type lines struct {
fsys fs.FS
}
func (l lines) BestIndex(idx *sqlite3.IndexInfo) (err error) {
err = sqlite3.CONSTRAINT
for i, cst := range idx.Constraint {
if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
continue
}
switch cst.Column {
case 1:
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
Omit: true,
ArgvIndex: 1,
}
idx.EstimatedCost = 1e6
idx.EstimatedRows = 100
err = nil
case 2:
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
Omit: true,
ArgvIndex: 2,
}
}
}
return err
}
func (l lines) Open() (sqlite3.VTabCursor, error) {
if l.fsys != nil {
return &reader{fsys: l.fsys}, nil
} else {
return &buffer{}, nil
}
}
type cursor struct {
line []byte
rowID int64
eof bool
delim byte
}
func (c *cursor) EOF() bool {
return c.eof
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
if n == 0 {
ctx.ResultRawText(c.line)
}
return nil
}
type reader struct {
fsys fs.FS
reader *bufio.Reader
closer io.Closer
cursor
}
func (c *reader) Close() (err error) {
if c.closer != nil {
err = c.closer.Close()
c.closer = nil
}
return err
}
func (c *reader) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
if err := c.Close(); err != nil {
return err
}
var r io.Reader
typ := arg[0].Type()
switch typ {
case sqlite3.NULL:
if p, ok := arg[0].Pointer().(io.Reader); ok {
r = p
}
case sqlite3.TEXT:
f, err := c.fsys.Open(arg[0].Text())
if err != nil {
return err
}
r = f
}
if r == nil {
return fmt.Errorf("lines: unsupported argument:%.0w %v", sqlite3.MISMATCH, typ)
}
c.delim = '\n'
if len(arg) > 1 {
b := arg[1].RawText()
if len(b) != 1 {
return fmt.Errorf("lines: delimiter must be a single byte%.0w", sqlite3.MISMATCH)
}
c.delim = b[0]
}
c.reader = bufio.NewReader(r)
c.closer, _ = r.(io.Closer)
c.rowID = 0
return c.Next()
}
func (c *reader) Next() (err error) {
c.line = c.line[:0]
for more := true; more; {
var line []byte
if c.delim == '\n' {
line, more, err = c.reader.ReadLine()
} else {
line, err = c.reader.ReadSlice(c.delim)
more = err == bufio.ErrBufferFull
}
c.line = append(c.line, line...)
}
if err == io.EOF {
c.eof = true
err = nil
}
c.rowID++
return err
}
type buffer struct {
data []byte
cursor
}
func (c *buffer) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
typ := arg[0].Type()
switch typ {
case sqlite3.TEXT:
c.data = arg[0].RawText()
case sqlite3.BLOB:
c.data = arg[0].RawBlob()
default:
return fmt.Errorf("lines: unsupported argument:%.0w %v", sqlite3.MISMATCH, typ)
}
c.delim = '\n'
if len(arg) > 1 {
b := arg[1].RawText()
if len(b) != 1 {
return fmt.Errorf("lines: delimiter must be a single byte%.0w", sqlite3.MISMATCH)
}
c.delim = b[0]
}
c.rowID = 0
return c.Next()
}
func (c *buffer) Next() error {
i := bytes.IndexByte(c.data, c.delim)
j := i + 1
switch {
case i < 0:
i = len(c.data)
j = i
case i > 0 && c.delim == '\n' && c.data[i-1] == '\r':
i--
}
c.eof = len(c.data) == 0
c.line = c.data[:i]
c.data = c.data[j:]
c.rowID++
return nil
}
// Package pivot implements a pivot virtual table.
//
// https://github.com/jakethaw/pivot_vtab
package pivot
import (
"errors"
"fmt"
"strings"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the pivot virtual table.
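//
// A minimal usage sketch, assuming db is a [sqlite3.Conn] opened elsewhere and
// hypothetical tables r(id), c(id, name) and x(r_id, c_id, val).
// The three arguments are the row key query, the column definition query
// (key, name), and the cell query, bound to a row key and a column key:
//
//	err := db.Exec(`CREATE VIRTUAL TABLE temp.pivoted USING pivot(
//		(SELECT id FROM r),
//		(SELECT id, name FROM c),
//		(SELECT val FROM x WHERE r_id = ?1 AND c_id = ?2))`)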
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "pivot", declare, declare)
}
type table struct {
db *sqlite3.Conn
scan string
cell string
keys []string
cols []*sqlite3.Value
}
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err error) {
if len(arg) != 3 {
return nil, fmt.Errorf("pivot: wrong number of arguments")
}
t := &table{db: db}
defer func() {
if ret == nil {
t.Close()
}
}()
var sep string
var create strings.Builder
create.WriteString("CREATE TABLE x(")
// Row key query.
t.scan = "SELECT * FROM\n" + arg[0]
stmt, _, err := db.Prepare(t.scan)
if err != nil {
return nil, err
}
defer stmt.Close()
t.keys = make([]string, stmt.ColumnCount())
for i := range t.keys {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
t.keys[i] = name
create.WriteString(sep)
create.WriteString(name)
sep = ","
}
stmt.Close()
// Column definition query.
stmt, _, err = db.Prepare("SELECT * FROM\n" + arg[1])
if err != nil {
return nil, err
}
if stmt.ColumnCount() != 2 {
return nil, util.ErrorString("pivot: column definition query expects 2 result columns")
}
for stmt.Step() {
name := sqlite3.QuoteIdentifier(stmt.ColumnText(1))
t.cols = append(t.cols, stmt.ColumnValue(0).Dup())
create.WriteString(",")
create.WriteString(name)
}
stmt.Close()
// Pivot cell query.
t.cell = "SELECT * FROM\n" + arg[2]
stmt, _, err = db.Prepare(t.cell)
if err != nil {
return nil, err
}
if stmt.ColumnCount() != 1 {
return nil, util.ErrorString("pivot: cell query expects 1 result columns")
}
if stmt.BindCount() != len(t.keys)+1 {
return nil, fmt.Errorf("pivot: cell query expects %d bound parameters", len(t.keys)+1)
}
create.WriteByte(')')
err = db.DeclareVTab(create.String())
if err != nil {
return nil, err
}
return t, nil
}
func (t *table) Close() error {
var errs []error
for _, c := range t.cols {
errs = append(errs, c.Close())
}
return errors.Join(errs...)
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
var idxStr strings.Builder
idxStr.WriteString(t.scan)
argvIndex := 1
sep := " WHERE "
for i, cst := range idx.Constraint {
if !cst.Usable || !(0 <= cst.Column && cst.Column < len(t.keys)) {
continue
}
op := operator(cst.Op)
if op == "" {
continue
}
idxStr.WriteString(sep)
idxStr.WriteString(t.keys[cst.Column])
idxStr.WriteString(" ")
idxStr.WriteString(op)
idxStr.WriteString(" ?")
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: argvIndex,
Omit: true,
}
sep = " AND "
argvIndex++
}
sep = " ORDER BY "
idx.OrderByConsumed = true
for _, ord := range idx.OrderBy {
if !(0 <= ord.Column && ord.Column < len(t.keys)) {
idx.OrderByConsumed = false
continue
}
idxStr.WriteString(sep)
idxStr.WriteString(t.keys[ord.Column])
idxStr.WriteString(" COLLATE ")
idxStr.WriteString(idx.Collation(ord.Column))
if ord.Desc {
idxStr.WriteString(" DESC")
}
sep = ","
}
idx.EstimatedCost = 1e9 / float64(argvIndex)
idx.IdxStr = idxStr.String()
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
return &cursor{table: t}, nil
}
func (t *table) Rename(new string) error {
return nil
}
type cursor struct {
table *table
scan *sqlite3.Stmt
cell *sqlite3.Stmt
rowID int64
}
func (c *cursor) Close() error {
return errors.Join(c.scan.Close(), c.cell.Close())
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
err := c.scan.Close()
if err != nil {
return err
}
c.scan, _, err = c.table.db.Prepare(idxStr)
if err != nil {
return err
}
for i, arg := range arg {
err := c.scan.BindValue(i+1, arg)
if err != nil {
return err
}
}
if c.cell == nil {
c.cell, _, err = c.table.db.Prepare(c.table.cell)
if err != nil {
return err
}
}
c.rowID = 0
return c.Next()
}
func (c *cursor) Next() error {
if c.scan.Step() {
count := c.scan.ColumnCount()
for i := range count {
err := c.cell.BindValue(i+1, c.scan.ColumnValue(i))
if err != nil {
return err
}
}
c.rowID++
}
return c.scan.Err()
}
func (c *cursor) EOF() bool {
return !c.scan.Busy()
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx sqlite3.Context, col int) error {
count := c.scan.ColumnCount()
if col < count {
ctx.ResultValue(c.scan.ColumnValue(col))
return nil
}
err := c.cell.BindValue(count+1, *c.table.cols[col-count])
if err != nil {
return err
}
if c.cell.Step() {
ctx.ResultValue(c.cell.ColumnValue(0))
}
return c.cell.Reset()
}
func operator(op sqlite3.IndexConstraintOp) string {
switch op {
case sqlite3.INDEX_CONSTRAINT_EQ:
return "="
case sqlite3.INDEX_CONSTRAINT_LT:
return "<"
case sqlite3.INDEX_CONSTRAINT_GT:
return ">"
case sqlite3.INDEX_CONSTRAINT_LE:
return "<="
case sqlite3.INDEX_CONSTRAINT_GE:
return ">="
case sqlite3.INDEX_CONSTRAINT_NE:
return "<>"
case sqlite3.INDEX_CONSTRAINT_MATCH:
return "MATCH"
case sqlite3.INDEX_CONSTRAINT_LIKE:
return "LIKE"
case sqlite3.INDEX_CONSTRAINT_GLOB:
return "GLOB"
case sqlite3.INDEX_CONSTRAINT_REGEXP:
return "REGEXP"
case sqlite3.INDEX_CONSTRAINT_IS, sqlite3.INDEX_CONSTRAINT_ISNULL:
return "IS"
case sqlite3.INDEX_CONSTRAINT_ISNOT, sqlite3.INDEX_CONSTRAINT_ISNOTNULL:
return "IS NOT"
default:
return ""
}
}
// Package regexp provides additional regular expression functions.
//
// It provides the following Unicode aware functions:
// - regexp_like(text, pattern),
// - regexp_count(text, pattern [, start]),
// - regexp_instr(text, pattern [, start [, N [, endoption [, subexpr ]]]]),
// - regexp_substr(text, pattern [, start [, N [, subexpr ]]]),
// - regexp_replace(text, pattern, replacement [, start [, N ]]),
// - and a REGEXP operator.
//
// The implementation uses Go [regexp/syntax] for regular expressions.
//
// https://github.com/nalgeon/sqlean/blob/main/docs/regexp.md
package regexp
import (
"errors"
"regexp"
"regexp/syntax"
"strings"
"unicode/utf8"
"github.com/ncruces/go-sqlite3"
)
// Register registers Unicode aware functions for a database connection.
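//
// A minimal usage sketch, assuming db is a [sqlite3.Conn] opened elsewhere:
//
//	if err := Register(db); err != nil {
//		// handle the registration error
//	}
//	stmt, _, err := db.Prepare(
//		`SELECT regexp_replace('hello world', '\w+', 'word')`)
//	if err == nil {
//		if stmt.Step() {
//			// yields "word word"
//		}
//		stmt.Close()
//	}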
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("regexp", 2, flags, regex),
db.CreateFunction("regexp_like", 2, flags, regexLike),
db.CreateFunction("regexp_count", 2, flags, regexCount),
db.CreateFunction("regexp_count", 3, flags, regexCount),
db.CreateFunction("regexp_instr", 2, flags, regexInstr),
db.CreateFunction("regexp_instr", 3, flags, regexInstr),
db.CreateFunction("regexp_instr", 4, flags, regexInstr),
db.CreateFunction("regexp_instr", 5, flags, regexInstr),
db.CreateFunction("regexp_instr", 6, flags, regexInstr),
db.CreateFunction("regexp_substr", 2, flags, regexSubstr),
db.CreateFunction("regexp_substr", 3, flags, regexSubstr),
db.CreateFunction("regexp_substr", 4, flags, regexSubstr),
db.CreateFunction("regexp_substr", 5, flags, regexSubstr),
db.CreateFunction("regexp_replace", 3, flags, regexReplace),
db.CreateFunction("regexp_replace", 4, flags, regexReplace),
db.CreateFunction("regexp_replace", 5, flags, regexReplace))
}
// GlobPrefix returns a GLOB for a regular expression
// appropriate to take advantage of the [LIKE optimization]
// in a query such as:
//
// SELECT column WHERE column GLOB :glob_prefix AND column REGEXP :regexp
//
// [LIKE optimization]: https://sqlite.org/optoverview.html#the_like_optimization
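//
// A minimal sketch, assuming db is a [sqlite3.Conn] with the REGEXP operator
// registered (e.g. by [Register]) and a hypothetical indexed table t(col):
//
//	pattern := `^abc[0-9]+$`
//	stmt, _, err := db.Prepare(
//		`SELECT col FROM t WHERE col GLOB ?1 AND col REGEXP ?2`)
//	if err == nil {
//		stmt.BindText(1, GlobPrefix(pattern)) // "abc*" narrows the scan
//		stmt.BindText(2, pattern)
//	}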
func GlobPrefix(expr string) string {
re, err := syntax.Parse(expr, syntax.Perl)
if err != nil {
return "" // no match possible
}
prog, err := syntax.Compile(re.Simplify())
if err != nil {
return "" // notest
}
i := &prog.Inst[prog.Start]
var empty syntax.EmptyOp
loop1:
for {
switch i.Op {
case syntax.InstFail:
return "" // notest
case syntax.InstCapture, syntax.InstNop:
// skip
case syntax.InstEmptyWidth:
empty |= syntax.EmptyOp(i.Arg)
default:
break loop1
}
i = &prog.Inst[i.Out]
}
if empty&syntax.EmptyBeginText == 0 {
return "*" // not anchored
}
var glob strings.Builder
loop2:
for {
switch i.Op {
case syntax.InstFail:
return "" // notest
case syntax.InstCapture, syntax.InstEmptyWidth, syntax.InstNop:
// skip
case syntax.InstRune, syntax.InstRune1:
if len(i.Rune) != 1 || syntax.Flags(i.Arg)&syntax.FoldCase != 0 {
break loop2
}
switch r := i.Rune[0]; r {
case '*', '?', '[', utf8.RuneError:
break loop2
default:
glob.WriteRune(r)
}
default:
break loop2
}
i = &prog.Inst[i.Out]
}
glob.WriteByte('*')
return glob.String()
}
func load(ctx sqlite3.Context, arg []sqlite3.Value, i int) (*regexp.Regexp, error) {
re, ok := ctx.GetAuxData(i).(*regexp.Regexp)
if !ok {
re, ok = arg[i].Pointer().(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[i].Text())
if err != nil {
return nil, err
}
re = r
}
ctx.SetAuxData(i, re)
}
return re, nil
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 0)
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[1].RawText()
ctx.ResultBool(re.Match(text))
}
func regexLike(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[0].RawText()
ctx.ResultBool(re.Match(text))
}
func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[0].RawText()
if len(arg) > 2 {
pos := arg[2].Int()
text = text[skip(text, pos):]
}
ctx.ResultInt(len(re.FindAll(text, -1)))
}
func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[0].RawText()
var pos, n, subexpr int
if len(arg) > 2 {
pos = arg[2].Int()
}
if len(arg) > 3 {
n = arg[3].Int()
}
if len(arg) > 4 {
subexpr = arg[4].Int()
}
loc := regexFind(re, text, pos, n, subexpr)
if loc != nil {
ctx.ResultRawText(text[loc[0]:loc[1]])
}
}
func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[0].RawText()
var pos, n, end, subexpr int
if len(arg) > 2 {
pos = arg[2].Int()
}
if len(arg) > 3 {
n = arg[3].Int()
}
if len(arg) > 4 && arg[4].Bool() {
end = 1
}
if len(arg) > 5 {
subexpr = arg[5].Int()
}
loc := regexFind(re, text, pos, n, subexpr)
if loc != nil {
ctx.ResultInt(loc[end] + 1)
}
}
func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
}
repl := arg[2].RawText()
text := arg[0].RawText()
var pos, n int
if len(arg) > 3 {
pos = arg[3].Int()
}
if len(arg) > 4 {
n = arg[4].Int()
}
res := text
pos = skip(text, pos)
if n > 0 {
all := re.FindAllSubmatchIndex(text[pos:], n)
if n <= len(all) {
loc := all[n-1]
res = text[:pos+loc[0]]
res = re.Expand(res, repl, text[pos:], loc)
res = append(res, text[pos+loc[1]:]...)
}
} else {
res = append(text[:pos], re.ReplaceAll(text[pos:], repl)...)
}
ctx.ResultRawText(res)
}
func regexFind(re *regexp.Regexp, text []byte, pos, n, subexpr int) (loc []int) {
pos = skip(text, pos)
text = text[pos:]
if n <= 1 {
if subexpr == 0 {
loc = re.FindIndex(text)
} else {
loc = re.FindSubmatchIndex(text)
}
} else {
if subexpr == 0 {
all := re.FindAllIndex(text, n)
if n <= len(all) {
loc = all[n-1]
}
} else {
all := re.FindAllSubmatchIndex(text, n)
if n <= len(all) {
loc = all[n-1]
}
}
}
if 2+2*subexpr <= len(loc) {
loc = loc[2*subexpr : 2+2*subexpr]
loc[0] += pos
loc[1] += pos
return loc
}
return nil
}
func skip(text []byte, start int) int {
for pos := range string(text) {
if start--; start <= 0 {
return pos
}
}
return len(text)
}
// Package serdes provides functions to (de)serialize databases.
package serdes
import (
"io"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
const vfsName = "github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS"
func init() {
vfs.Register(vfsName, sliceVFS{})
}
var fileToOpen = make(chan *sliceFile, 1)
// Serialize backs up a database into a byte slice.
//
// https://sqlite.org/c3ref/serialize.html
func Serialize(db *sqlite3.Conn, schema string) ([]byte, error) {
var file sliceFile
fileToOpen <- &file
err := db.Backup(schema, "file:serdes.db?vfs="+vfsName)
return file.data, err
}
// Deserialize restores a database from a byte slice,
// DESTROYING any contents previously stored in schema.
//
// To non-destructively open a database from a byte slice,
// consider alternatives like the ["reader"] or ["memdb"] VFSes.
//
// This differs from the similarly named SQLite API
// in that it DOES NOT disconnect from schema
// to reopen as an in-memory database.
//
// https://sqlite.org/c3ref/deserialize.html
//
// ["memdb"]: https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb
// ["reader"]: https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs
func Deserialize(db *sqlite3.Conn, schema string, data []byte) error {
fileToOpen <- &sliceFile{data}
return db.Restore(schema, "file:serdes.db?vfs="+vfsName)
}
type sliceVFS struct{}
func (sliceVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
if flags&vfs.OPEN_MAIN_DB == 0 || name != "serdes.db" {
return nil, flags, sqlite3.CANTOPEN
}
select {
case file := <-fileToOpen:
return file, flags | vfs.OPEN_MEMORY, nil
default:
return nil, flags, sqlite3.MISUSE
}
}
func (sliceVFS) Delete(name string, dirSync bool) error {
// notest // OPEN_MEMORY
return sqlite3.IOERR_DELETE
}
func (sliceVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
return name == "serdes.db", nil
}
func (sliceVFS) FullPathname(name string) (string, error) {
return name, nil
}
type sliceFile struct{ data []byte }
func (f *sliceFile) ReadAt(b []byte, off int64) (n int, err error) {
if d := f.data; off < int64(len(d)) {
n = copy(b, d[off:])
}
if n == 0 {
err = io.EOF
}
return
}
func (f *sliceFile) WriteAt(b []byte, off int64) (n int, err error) {
if d := f.data; off > int64(len(d)) {
f.data = append(d, make([]byte, off-int64(len(d)))...)
}
d := append(f.data[:off], b...)
if len(d) > len(f.data) {
f.data = d
}
return len(b), nil
}
func (f *sliceFile) Size() (int64, error) {
return int64(len(f.data)), nil
}
func (f *sliceFile) Truncate(size int64) error {
if d := f.data; size < int64(len(d)) {
f.data = d[:size]
}
return nil
}
func (f *sliceFile) SizeHint(size int64) error {
if d := f.data; size > int64(len(d)) {
f.data = append(d, make([]byte, size-int64(len(d)))...)
}
return nil
}
func (*sliceFile) Close() error { return nil }
func (*sliceFile) Sync(flag vfs.SyncFlag) error { return nil }
func (*sliceFile) Lock(lock vfs.LockLevel) error { return nil }
func (*sliceFile) Unlock(lock vfs.LockLevel) error { return nil }
func (*sliceFile) CheckReservedLock() (bool, error) {
// notest // OPEN_MEMORY
return false, nil
}
func (*sliceFile) SectorSize() int {
// notest // IOCAP_POWERSAFE_OVERWRITE
return 0
}
func (*sliceFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
return vfs.IOCAP_ATOMIC |
vfs.IOCAP_SAFE_APPEND |
vfs.IOCAP_SEQUENTIAL |
vfs.IOCAP_POWERSAFE_OVERWRITE |
vfs.IOCAP_SUBPAGE_READ
}
// Package statement defines table-valued functions using SQL.
//
// It can be used to create "parametrized views":
// pre-packaged queries that can be parametrized at query execution time.
//
// https://github.com/0x09/sqlite-statement-vtab
package statement
import (
"encoding/json"
"errors"
"strconv"
"strings"
"unsafe"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the statement virtual table.
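//
// A minimal usage sketch, assuming db is a [sqlite3.Conn] opened elsewhere and
// a hypothetical table employees(name, salary):
//
//	err := db.Exec(`CREATE VIRTUAL TABLE temp.well_paid USING statement(
//		(SELECT name FROM employees WHERE salary > :min_salary))`)
//	if err != nil {
//		// handle the error
//	}
//	// :min_salary is exposed as a hidden column, so it can be set at query time:
//	stmt, _, err := db.Prepare(`SELECT name FROM well_paid(100000)`)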
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "statement", declare, declare)
}
type table struct {
stmt *sqlite3.Stmt
sql string
inuse bool
}
func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
if len(arg) != 1 {
return nil, util.ErrorString("statement: wrong number of arguments")
}
sql := "SELECT * FROM\n" + arg[0]
stmt, _, err := db.PrepareFlags(sql, sqlite3.PREPARE_PERSISTENT)
if err != nil {
return nil, err
}
var sep string
var str strings.Builder
str.WriteString("CREATE TABLE x(")
outputs := stmt.ColumnCount()
for i := range outputs {
name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
str.WriteString(sep)
str.WriteString(name)
str.WriteString(" ")
str.WriteString(stmt.ColumnDeclType(i))
sep = ","
}
inputs := stmt.BindCount()
for i := 1; i <= inputs; i++ {
str.WriteString(sep)
name := stmt.BindName(i)
if name == "" {
str.WriteString("[")
str.WriteString(strconv.Itoa(i))
str.WriteString("] HIDDEN")
} else {
str.WriteString(sqlite3.QuoteIdentifier(name[1:]))
str.WriteString(" HIDDEN")
}
sep = ","
}
str.WriteByte(')')
err = db.DeclareVTab(str.String())
if err != nil {
stmt.Close()
return nil, err
}
return &table{sql: sql, stmt: stmt}, nil
}
func (t *table) Close() error {
return t.stmt.Close()
}
func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
idx.EstimatedCost = 1000
var argvIndex = 1
var needIndex bool
var listIndex []int
outputs := t.stmt.ColumnCount()
for i, cst := range idx.Constraint {
// Skip if this is a constraint on one of our output columns.
if cst.Column < outputs {
continue
}
// A given query plan is only usable if all provided input columns
// are usable and constrained only by equality.
if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
return sqlite3.CONSTRAINT
}
// The non-zero argvIdx values must be contiguous.
// If they're not, build a list and serialize it through IdxStr.
nextIndex := cst.Column - outputs + 1
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: argvIndex,
Omit: true,
}
if nextIndex != argvIndex {
needIndex = true
}
listIndex = append(listIndex, nextIndex)
argvIndex++
}
if needIndex {
buf, err := json.Marshal(listIndex)
if err != nil {
return err
}
idx.IdxStr = unsafe.String(&buf[0], len(buf))
}
return nil
}
func (t *table) Open() (_ sqlite3.VTabCursor, err error) {
stmt := t.stmt
if !t.inuse {
t.inuse = true
} else {
stmt, _, err = t.stmt.Conn().Prepare(t.sql)
if err != nil {
return nil, err
}
}
return &cursor{table: t, stmt: stmt}, nil
}
func (t *table) Rename(new string) error {
return nil
}
type cursor struct {
table *table
stmt *sqlite3.Stmt
arg []sqlite3.Value
rowID int64
}
func (c *cursor) Close() error {
if c.stmt == c.table.stmt {
c.table.inuse = false
return errors.Join(
c.stmt.Reset(),
c.stmt.ClearBindings())
}
return c.stmt.Close()
}
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
err := errors.Join(
c.stmt.Reset(),
c.stmt.ClearBindings())
if err != nil {
return err
}
var list []int
if idxStr != "" {
buf := unsafe.Slice(unsafe.StringData(idxStr), len(idxStr))
err := json.Unmarshal(buf, &list)
if err != nil {
return err
}
}
for i, arg := range arg {
param := i + 1
if list != nil {
param = list[i]
}
err := c.stmt.BindValue(param, arg)
if err != nil {
return err
}
}
c.arg = append(c.arg[:0], arg...)
c.rowID = 0
return c.Next()
}
func (c *cursor) Next() error {
if c.stmt.Step() {
c.rowID++
}
return c.stmt.Err()
}
func (c *cursor) EOF() bool {
return !c.stmt.Busy()
}
func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}
func (c *cursor) Column(ctx sqlite3.Context, col int) error {
switch outputs := c.stmt.ColumnCount(); {
case col < outputs:
ctx.ResultValue(c.stmt.ColumnValue(col))
case col-outputs < len(c.arg):
ctx.ResultValue(c.arg[col-outputs])
}
return nil
}
package stats
import "github.com/ncruces/go-sqlite3"
const (
every = iota
some
)
func newBoolean(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
}
type boolean struct {
count int
total int
kind int
}
func (b *boolean) Value(ctx sqlite3.Context) {
if b.kind == every {
ctx.ResultBool(b.count == b.total)
} else {
ctx.ResultBool(b.count > 0)
}
}
func (b *boolean) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
if a.Bool() {
b.count++
}
if a.Type() != sqlite3.NULL {
b.total++
}
}
func (b *boolean) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
if a.Bool() {
b.count--
}
if a.Type() != sqlite3.NULL {
b.total--
}
}
package stats
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
package stats
import (
"unsafe"
"github.com/ncruces/go-sqlite3"
)
func newMode() sqlite3.AggregateFunction {
return &mode{}
}
type mode struct {
ints counter[int64]
reals counter[float64]
texts counter[string]
blobs counter[string]
}
func (m mode) Value(ctx sqlite3.Context) {
var (
max = 0
typ = sqlite3.NULL
i64 int64
f64 float64
str string
)
for k, v := range m.ints {
if v > max || v == max && k < i64 {
typ = sqlite3.INTEGER
max = v
i64 = k
}
}
f64 = float64(i64)
for k, v := range m.reals {
if v > max || v == max && k < f64 {
typ = sqlite3.FLOAT
max = v
f64 = k
}
}
for k, v := range m.texts {
if v > max || v == max && typ == sqlite3.TEXT && k < str {
typ = sqlite3.TEXT
max = v
str = k
}
}
for k, v := range m.blobs {
if v > max || v == max && typ == sqlite3.BLOB && k < str {
typ = sqlite3.BLOB
max = v
str = k
}
}
switch typ {
case sqlite3.INTEGER:
ctx.ResultInt64(i64)
case sqlite3.FLOAT:
ctx.ResultFloat(f64)
case sqlite3.TEXT:
ctx.ResultText(str)
case sqlite3.BLOB:
ctx.ResultBlob(unsafe.Slice(unsafe.StringData(str), len(str)))
}
}
func (b *mode) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
switch arg[0].Type() {
case sqlite3.INTEGER:
b.ints.add(arg[0].Int64())
case sqlite3.FLOAT:
b.reals.add(arg[0].Float())
case sqlite3.TEXT:
b.texts.add(arg[0].Text())
case sqlite3.BLOB:
b.blobs.add(string(arg[0].RawBlob()))
}
}
func (b *mode) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
switch arg[0].Type() {
case sqlite3.INTEGER:
b.ints.del(arg[0].Int64())
case sqlite3.FLOAT:
b.reals.del(arg[0].Float())
case sqlite3.TEXT:
b.texts.del(arg[0].Text())
case sqlite3.BLOB:
b.blobs.del(string(arg[0].RawBlob()))
}
}
type counter[T comparable] map[T]int
func (c *counter[T]) add(k T) {
if (*c) == nil {
(*c) = make(counter[T])
}
(*c)[k]++
}
func (c counter[T]) del(k T) {
switch n := c[k]; n {
default:
c[k] = n - 1
case 1:
delete(c, k)
case 0:
}
}
package stats
import "math"
// Fisher–Pearson skewness and kurtosis using
// Terriberry's algorithm with Kahan summation:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
type moments struct {
m1, m2, m3, m4 kahan
n int64
}
func (m moments) mean() float64 {
return m.m1.hi
}
func (m moments) var_pop() float64 {
return m.m2.hi / float64(m.n)
}
func (m moments) var_samp() float64 {
return m.m2.hi / float64(m.n-1) // Bessel's correction
}
func (m moments) stddev_pop() float64 {
return math.Sqrt(m.var_pop())
}
func (m moments) stddev_samp() float64 {
return math.Sqrt(m.var_samp())
}
func (m moments) skewness_pop() float64 {
m2 := m.m2.hi
if div := m2 * m2 * m2; div != 0 {
return m.m3.hi * math.Sqrt(float64(m.n)/div)
}
return math.NaN()
}
func (m moments) skewness_samp() float64 {
n := m.n
// https://mathworks.com/help/stats/skewness.html#f1132178
return m.skewness_pop() * math.Sqrt(float64(n*(n-1))) / float64(n-2)
}
func (m moments) kurtosis_pop() float64 {
return m.raw_kurtosis_pop() - 3
}
func (m moments) raw_kurtosis_pop() float64 {
m2 := m.m2.hi
if div := m2 * m2; div != 0 {
return m.m4.hi * float64(m.n) / div
}
return math.NaN()
}
func (m moments) kurtosis_samp() float64 {
n := m.n
k := math.FMA(m.raw_kurtosis_pop(), float64(n+1), float64(3-3*n))
return k * float64(n-1) / float64((n-2)*(n-3))
}
func (m moments) raw_kurtosis_samp() float64 {
n := m.n
// https://mathworks.com/help/stats/kurtosis.html#f4975293
k := math.FMA(m.raw_kurtosis_pop(), float64(n+1), float64(3-3*n))
return math.FMA(k, float64(n-1)/float64((n-2)*(n-3)), 3)
}
func (m *moments) enqueue(x float64) {
n := m.n + 1
m.n = n
d1 := x - m.m1.hi - m.m1.lo
dn := d1 / float64(n)
d2 := dn * dn
t1 := d1 * dn * float64(n-1)
m.m4.add(t1*d2*float64(n*n-3*n+3) + 6*d2*m.m2.hi - 4*dn*m.m3.hi)
m.m3.add(t1*dn*float64(n-2) - 3*dn*m.m2.hi)
m.m2.add(t1)
m.m1.add(dn)
}
func (m *moments) dequeue(x float64) {
n := m.n - 1
if n <= 0 {
*m = moments{}
return
}
m.n = n
d1 := x - m.m1.hi - m.m1.lo
dn := d1 / float64(n)
d2 := dn * dn
t1 := d1 * dn * float64(n+1)
m.m4.sub(t1*d2*float64(n*n+3*n+3) - 6*d2*m.m2.hi - 4*dn*m.m3.hi)
m.m3.sub(t1*dn*float64(n+2) - 3*dn*m.m2.hi)
m.m2.sub(t1)
m.m1.sub(dn)
}
package stats
import (
"encoding/json"
"fmt"
"math"
"slices"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/sort/quick"
)
// Compatible with:
// https://sqlite.org/src/file/ext/misc/percentile.c
const (
median = iota
percentile_100
percentile_cont
percentile_disc
)
func newPercentile(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &percentile{kind: kind} }
}
type percentile struct {
nums []float64
arg1 []byte
kind int
}
func (q *percentile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
q.nums = append(q.nums, f)
}
if q.kind != median && q.arg1 == nil {
q.arg1 = append(q.arg1, arg[1].RawText()...)
}
}
func (q *percentile) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
i := slices.Index(q.nums, f)
l := len(q.nums) - 1
q.nums[i] = q.nums[l]
q.nums = q.nums[:l]
}
}
func (q *percentile) Value(ctx sqlite3.Context) {
if len(q.nums) == 0 {
return
}
var (
err error
float float64
floats []float64
)
if q.kind == median {
float, err = q.at(0.5)
ctx.ResultFloat(float)
} else if err = json.Unmarshal(q.arg1, &float); err == nil {
float, err = q.at(float)
ctx.ResultFloat(float)
} else if err = json.Unmarshal(q.arg1, &floats); err == nil {
err = q.atMore(floats)
ctx.ResultJSON(floats)
}
if err != nil {
ctx.ResultError(fmt.Errorf("percentile: %w", err)) // notest
}
}
func (q *percentile) at(pos float64) (float64, error) {
if q.kind == percentile_100 {
pos = pos / 100
}
if pos < 0 || pos > 1 {
return 0, util.ErrorString("invalid pos")
}
i, f := math.Modf(pos * float64(len(q.nums)-1))
m0 := quick.Select(q.nums, int(i))
if f == 0 || q.kind == percentile_disc {
return m0, nil
}
m1 := slices.Min(q.nums[int(i)+1:])
return util.Lerp(m0, m1, f), nil
}
func (q *percentile) atMore(pos []float64) error {
for i := range pos {
v, err := q.at(pos[i])
if err != nil {
return err
}
pos[i] = v
}
return nil
}
// Package stats provides aggregate functions for statistics.
//
// Provided functions:
// - var_pop: population variance
// - var_samp: sample variance
// - stddev_pop: population standard deviation
// - stddev_samp: sample standard deviation
// - skewness_pop: Pearson population skewness
// - skewness_samp: Pearson sample skewness
// - kurtosis_pop: Fisher population excess kurtosis
// - kurtosis_samp: Fisher sample excess kurtosis
// - covar_pop: population covariance
// - covar_samp: sample covariance
// - corr: Pearson correlation coefficient
// - regr_r2: correlation coefficient squared
// - regr_avgx: average of the independent variable
// - regr_avgy: average of the dependent variable
// - regr_sxx: sum of the squares of the independent variable
// - regr_syy: sum of the squares of the dependent variable
// - regr_sxy: sum of the products of each pair of variables
// - regr_count: count non-null pairs of variables
// - regr_slope: slope of the least-squares-fit linear equation
// - regr_intercept: y-intercept of the least-squares-fit linear equation
// - regr_json: all regr stats as a JSON object
// - percentile_disc: discrete quantile
// - percentile_cont: continuous quantile
// - percentile: continuous percentile
// - median: middle value
// - mode: most frequent value
// - every: boolean and
// - some: boolean or
//
// These join the [Built-in Aggregate Functions]:
// - count: count rows/values
// - sum: sum values
// - avg: average value
// - min: minimum value
// - max: maximum value
//
// And the [Built-in Window Functions]:
// - rank: rank of the current row with gaps
// - dense_rank: rank of the current row without gaps
// - percent_rank: relative rank of the row
// - cume_dist: cumulative distribution
//
// See: [ANSI SQL Aggregate Functions]
//
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
// [Built-in Window Functions]: https://sqlite.org/windowfunctions.html#builtins
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
package stats
import (
"errors"
"github.com/ncruces/go-sqlite3"
)
// Register registers statistics functions.
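//
// A minimal usage sketch, assuming db is a [sqlite3.Conn] opened elsewhere and
// a hypothetical table samples(x, y):
//
//	if err := Register(db); err != nil {
//		// handle the registration error
//	}
//	stmt, _, err := db.Prepare(`
//		SELECT median(x), percentile(x, 90), stddev_samp(x), corr(y, x)
//		FROM samples`)
//	if err == nil {
//		if stmt.Step() {
//			// one row holding the four statistics
//		}
//		stmt.Close()
//	}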
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
const order = sqlite3.SELFORDER1 | flags
return errors.Join(
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)),
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)),
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)),
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)),
db.CreateWindowFunction("skewness_pop", 1, flags, newMoments(skewness_pop)),
db.CreateWindowFunction("skewness_samp", 1, flags, newMoments(skewness_samp)),
db.CreateWindowFunction("kurtosis_pop", 1, flags, newMoments(kurtosis_pop)),
db.CreateWindowFunction("kurtosis_samp", 1, flags, newMoments(kurtosis_samp)),
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)),
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)),
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)),
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2)),
db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx)),
db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy)),
db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy)),
db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx)),
db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy)),
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope)),
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)),
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)),
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)),
db.CreateWindowFunction("median", 1, order, newPercentile(median)),
db.CreateWindowFunction("percentile", 2, order, newPercentile(percentile_100)),
db.CreateWindowFunction("percentile_cont", 2, order, newPercentile(percentile_cont)),
db.CreateWindowFunction("percentile_disc", 2, order, newPercentile(percentile_disc)),
db.CreateWindowFunction("every", 1, flags, newBoolean(every)),
db.CreateWindowFunction("some", 1, flags, newBoolean(some)),
db.CreateWindowFunction("mode", 1, order, newMode))
}
const (
var_pop = iota
var_samp
stddev_pop
stddev_samp
skewness_pop
skewness_samp
kurtosis_pop
kurtosis_samp
corr
regr_r2
regr_sxx
regr_syy
regr_sxy
regr_avgx
regr_avgy
regr_slope
regr_intercept
regr_count
regr_json
)
func special(kind int, n int64) (null, zero bool) {
switch kind {
case var_pop, stddev_pop, regr_sxx, regr_syy, regr_sxy:
return n <= 0, n == 1
case regr_avgx, regr_avgy:
return n <= 0, false
case kurtosis_samp:
return n <= 3, false
case skewness_samp:
return n <= 2, false
case skewness_pop:
return n <= 1, n == 2
default:
return n <= 1, false
}
}
func newVariance(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
}
type variance struct {
kind int
welford
}
func (fn *variance) Value(ctx sqlite3.Context) {
switch null, zero := special(fn.kind, fn.n); {
case zero:
ctx.ResultFloat(0)
return
case null:
return
}
var r float64
switch fn.kind {
case var_pop:
r = fn.var_pop()
case var_samp:
r = fn.var_samp()
case stddev_pop:
r = fn.stddev_pop()
case stddev_samp:
r = fn.stddev_samp()
}
ctx.ResultFloat(r)
}
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
fn.enqueue(f)
}
}
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
fn.dequeue(f)
}
}
func newCovariance(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
}
type covariance struct {
kind int
welford2
}
func (fn *covariance) Value(ctx sqlite3.Context) {
if fn.kind == regr_count {
ctx.ResultInt64(fn.regr_count())
return
}
switch null, zero := special(fn.kind, fn.n); {
case zero:
ctx.ResultFloat(0)
return
case null:
return
}
var r float64
switch fn.kind {
case var_pop:
r = fn.covar_pop()
case var_samp:
r = fn.covar_samp()
case corr:
r = fn.correlation()
case regr_r2:
r = fn.regr_r2()
case regr_sxx:
r = fn.regr_sxx()
case regr_syy:
r = fn.regr_syy()
case regr_sxy:
r = fn.regr_sxy()
case regr_avgx:
r = fn.regr_avgx()
case regr_avgy:
r = fn.regr_avgy()
case regr_slope:
r = fn.regr_slope()
case regr_intercept:
r = fn.regr_intercept()
case regr_json:
var buf [128]byte
ctx.ResultRawText(fn.regr_json(buf[:0]))
return
}
ctx.ResultFloat(r)
}
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
b, a := arg[1], arg[0] // avoid a bounds check
fa := a.Float()
fb := b.Float()
if true &&
(fa != 0.0 || a.NumericType() != sqlite3.NULL) &&
(fb != 0.0 || b.NumericType() != sqlite3.NULL) {
fn.enqueue(fa, fb)
}
}
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
b, a := arg[1], arg[0] // avoid a bounds check
fa := a.Float()
fb := b.Float()
if true &&
(fa != 0.0 || a.NumericType() != sqlite3.NULL) &&
(fb != 0.0 || b.NumericType() != sqlite3.NULL) {
fn.dequeue(fa, fb)
}
}
func newMoments(kind int) sqlite3.AggregateConstructor {
return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} }
}
type momentfn struct {
kind int
moments
}
func (fn *momentfn) Value(ctx sqlite3.Context) {
switch null, zero := special(fn.kind, fn.n); {
case zero:
ctx.ResultFloat(0)
return
case null:
return
}
var r float64
switch fn.kind {
case skewness_pop:
r = fn.skewness_pop()
case skewness_samp:
r = fn.skewness_samp()
case kurtosis_pop:
r = fn.kurtosis_pop()
case kurtosis_samp:
r = fn.kurtosis_samp()
}
ctx.ResultFloat(r)
}
func (fn *momentfn) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
fn.enqueue(f)
}
}
func (fn *momentfn) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
a := arg[0]
f := a.Float()
if f != 0.0 || a.NumericType() != sqlite3.NULL {
fn.dequeue(f)
}
}
package stats
import (
"math"
"strconv"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Welford's algorithm with Kahan summation:
// The effect of truncation in statistical computation [van Reeken, AJ 1970]
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
type welford struct {
m1, m2 kahan
n int64
}
func (w welford) mean() float64 {
return w.m1.hi
}
func (w welford) var_pop() float64 {
return w.m2.hi / float64(w.n)
}
func (w welford) var_samp() float64 {
return w.m2.hi / float64(w.n-1) // Bessel's correction
}
func (w welford) stddev_pop() float64 {
return math.Sqrt(w.var_pop())
}
func (w welford) stddev_samp() float64 {
return math.Sqrt(w.var_samp())
}
func (w *welford) enqueue(x float64) {
n := w.n + 1
w.n = n
d1 := x - w.m1.hi - w.m1.lo
w.m1.add(d1 / float64(n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.add(d1 * d2)
}
func (w *welford) dequeue(x float64) {
n := w.n - 1
if n <= 0 {
*w = welford{}
return
}
w.n = n
d1 := x - w.m1.hi - w.m1.lo
w.m1.sub(d1 / float64(n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.sub(d1 * d2)
}
type welford2 struct {
m1y, m2y kahan
m1x, m2x kahan
cov kahan
n int64
}
func (w welford2) covar_pop() float64 {
return w.cov.hi / float64(w.n)
}
func (w welford2) covar_samp() float64 {
return w.cov.hi / float64(w.n-1) // Bessel's correction
}
func (w welford2) correlation() float64 {
return w.cov.hi / math.Sqrt(w.m2y.hi*w.m2x.hi)
}
func (w welford2) regr_avgy() float64 {
return w.m1y.hi
}
func (w welford2) regr_avgx() float64 {
return w.m1x.hi
}
func (w welford2) regr_syy() float64 {
return w.m2y.hi
}
func (w welford2) regr_sxx() float64 {
return w.m2x.hi
}
func (w welford2) regr_sxy() float64 {
return w.cov.hi
}
func (w welford2) regr_count() int64 {
return w.n
}
func (w welford2) regr_slope() float64 {
return w.cov.hi / w.m2x.hi
}
func (w welford2) regr_intercept() float64 {
slope := -w.regr_slope()
hi := math.FMA(slope, w.m1x.hi, w.m1y.hi)
lo := math.FMA(slope, w.m1x.lo, w.m1y.lo)
return hi + lo
}
func (w welford2) regr_r2() float64 {
return w.cov.hi * w.cov.hi / (w.m2y.hi * w.m2x.hi)
}
func (w welford2) regr_json(dst []byte) []byte {
dst = append(dst, `{"count":`...)
dst = strconv.AppendInt(dst, w.regr_count(), 10)
dst = append(dst, `,"avgy":`...)
dst = util.AppendNumber(dst, w.regr_avgy())
dst = append(dst, `,"avgx":`...)
dst = util.AppendNumber(dst, w.regr_avgx())
dst = append(dst, `,"syy":`...)
dst = util.AppendNumber(dst, w.regr_syy())
dst = append(dst, `,"sxx":`...)
dst = util.AppendNumber(dst, w.regr_sxx())
dst = append(dst, `,"sxy":`...)
dst = util.AppendNumber(dst, w.regr_sxy())
dst = append(dst, `,"slope":`...)
dst = util.AppendNumber(dst, w.regr_slope())
dst = append(dst, `,"intercept":`...)
dst = util.AppendNumber(dst, w.regr_intercept())
dst = append(dst, `,"r2":`...)
dst = util.AppendNumber(dst, w.regr_r2())
return append(dst, '}')
}
func (w *welford2) enqueue(y, x float64) {
n := w.n + 1
w.n = n
d1y := y - w.m1y.hi - w.m1y.lo
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.add(d1y / float64(n))
w.m1x.add(d1x / float64(n))
d2y := y - w.m1y.hi - w.m1y.lo
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.add(d1y * d2y)
w.m2x.add(d1x * d2x)
w.cov.add(d1y * d2x)
}
func (w *welford2) dequeue(y, x float64) {
n := w.n - 1
if n <= 0 {
*w = welford2{}
return
}
w.n = n
d1y := y - w.m1y.hi - w.m1y.lo
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.sub(d1y / float64(n))
w.m1x.sub(d1x / float64(n))
d2y := y - w.m1y.hi - w.m1y.lo
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.sub(d1y * d2y)
w.m2x.sub(d1x * d2x)
w.cov.sub(d1y * d2x)
}
// Package unicode provides an alternative to the SQLite ICU extension.
//
// Like the [ICU extension], it provides Unicode aware:
// - upper() and lower() functions
// - LIKE and REGEXP operators
// - collation sequences
//
// Like PostgreSQL, it also provides:
// - initcap()
// - casefold()
// - normalize()
// - unaccent()
//
// The implementations are not 100% compatible:
// - upper(), lower(), initcap() and casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases]
// - normalize(), unaccent() use [transform] and [unicode.Mn]
// - the LIKE operator follows [strings.EqualFold] rules
// - the REGEXP operator uses Go [regexp/syntax]
// - collation sequences use [collate]
//
// Expect subtle differences, e.g. in the handling of Turkish case folding.
//
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
package unicode
import (
"bytes"
"errors"
"regexp"
"strings"
"sync"
"unicode"
"unicode/utf8"
"golang.org/x/text/cases"
"golang.org/x/text/collate"
"golang.org/x/text/language"
"golang.org/x/text/runes"
"golang.org/x/text/transform"
"golang.org/x/text/unicode/norm"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Set RegisterLike to false to avoid registering a Unicode aware LIKE operator.
// Overriding the built-in LIKE operator disables the [LIKE optimization].
//
// [LIKE optimization]: https://sqlite.org/optoverview.html#the_like_optimization
var RegisterLike = true
// Register registers Unicode aware functions for a database connection.
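//
// A minimal usage sketch, assuming db is a [sqlite3.Conn] opened elsewhere:
//
//	if err := Register(db); err != nil {
//		// handle the registration error
//	}
//	stmt, _, err := db.Prepare(
//		`SELECT upper('ıi', 'tr'), unaccent('ação'), 'HELLO' LIKE 'hello'`)
//	if err == nil {
//		if stmt.Step() {
//			// yields 'Iİ', 'acao' and 1
//		}
//		stmt.Close()
//	}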
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
var lkfn sqlite3.ScalarFunction
if RegisterLike {
lkfn = like
}
return errors.Join(
db.CreateFunction("like", 2, flags, lkfn),
db.CreateFunction("like", 3, flags, lkfn),
db.CreateFunction("upper", 1, flags, upper),
db.CreateFunction("upper", 2, flags, upper),
db.CreateFunction("lower", 1, flags, lower),
db.CreateFunction("lower", 2, flags, lower),
db.CreateFunction("regexp", 2, flags, regex),
db.CreateFunction("initcap", 1, flags, initcap),
db.CreateFunction("initcap", 2, flags, initcap),
db.CreateFunction("casefold", 1, flags, casefold),
db.CreateFunction("unaccent", 1, flags, unaccent),
db.CreateFunction("normalize", 1, flags, normalize),
db.CreateFunction("normalize", 2, flags, normalize),
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
name := arg[1].Text()
if name == "" {
return
}
err := RegisterCollation(ctx.Conn(), arg[0].Text(), name)
if err != nil {
ctx.ResultError(err)
return // notest
}
}))
}
// RegisterCollation registers a Unicode collation sequence for a database connection.
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
tag, err := language.Parse(locale)
if err != nil {
return err
}
return db.CreateCollation(name, collate.New(tag).Compare)
}
// RegisterCollationsNeeded registers Unicode collation sequences on demand for a database connection.
func RegisterCollationsNeeded(db *sqlite3.Conn) error {
return db.CollationNeeded(func(db *sqlite3.Conn, name string) {
if tag, err := language.Parse(name); err == nil {
db.CreateCollation(name, collate.New(tag).Compare)
}
})
}
func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) == 1 {
ctx.ResultRawText(bytes.ToUpper(arg[0].RawText()))
return
}
cs, ok := ctx.GetAuxData(1).(cases.Caser)
if !ok {
t, err := language.Parse(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
cs = cases.Upper(t)
ctx.SetAuxData(1, cs)
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) == 1 {
ctx.ResultRawText(bytes.ToLower(arg[0].RawText()))
return
}
cs, ok := ctx.GetAuxData(1).(cases.Caser)
if !ok {
t, err := language.Parse(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
cs = cases.Lower(t)
ctx.SetAuxData(1, cs)
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
func initcap(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) == 1 {
ctx.ResultRawText(bytes.Title(arg[0].RawText()))
return
}
cs, ok := ctx.GetAuxData(1).(cases.Caser)
if !ok {
t, err := language.Parse(arg[1].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
cs = cases.Title(t)
ctx.SetAuxData(1, cs)
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
func casefold(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultRawText(cases.Fold().Bytes(arg[0].RawText()))
}
var unaccentPool = sync.Pool{
New: func() any {
return transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
},
}
func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
unaccent := unaccentPool.Get().(transform.Transformer)
defer unaccentPool.Put(unaccent)
res, _, err := transform.Bytes(unaccent, arg[0].RawText())
if err != nil {
ctx.ResultError(err) // notest
} else {
ctx.ResultRawText(res)
}
}
func normalize(ctx sqlite3.Context, arg ...sqlite3.Value) {
form := norm.NFC
if len(arg) > 1 {
switch strings.ToUpper(arg[1].Text()) {
case "NFC":
//
case "NFD":
form = norm.NFD
case "NFKC":
form = norm.NFKC
case "NFKD":
form = norm.NFKD
default:
ctx.ResultError(util.ErrorString("unicode: invalid form"))
return
}
}
res, _, err := transform.Bytes(form, arg[0].RawText())
if err != nil {
ctx.ResultError(err) // notest
} else {
ctx.ResultRawText(res)
}
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
re, ok = arg[0].Pointer().(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
re = r
}
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawText()))
}
func like(ctx sqlite3.Context, arg ...sqlite3.Value) {
escape := rune(-1)
if len(arg) == 3 {
var size int
b := arg[2].RawText()
escape, size = utf8.DecodeRune(b)
if size != len(b) {
ctx.ResultError(util.ErrorString("ESCAPE expression must be a single character"))
return
}
}
_ = arg[1] // bounds check
type likeData struct {
*regexp.Regexp
escape rune
}
re, ok := ctx.GetAuxData(0).(likeData)
if !ok || re.escape != escape {
re = likeData{
regexp.MustCompile(like2regex(arg[0].Text(), escape)),
escape,
}
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawText()))
}
func like2regex(pattern string, escape rune) string {
var re strings.Builder
start := 0
literal := false
re.Grow(len(pattern) + 10)
re.WriteString(`(?is)\A`) // case insensitive, . matches any character
for i, r := range pattern {
if start < 0 {
start = i
}
if literal {
literal = false
continue
}
var symbol string
switch r {
case '_':
symbol = `.`
case '%':
symbol = `.*`
case escape:
literal = true
default:
continue
}
re.WriteString(regexp.QuoteMeta(pattern[start:i]))
re.WriteString(symbol)
start = -1
}
if start >= 0 {
re.WriteString(regexp.QuoteMeta(pattern[start:]))
}
re.WriteString(`\z`)
return re.String()
}
// Package uuid provides functions to generate RFC 4122 UUIDs.
//
// https://sqlite.org/src/file/ext/misc/uuid.c
package uuid
import (
"bytes"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the SQL functions:
//
// - uuid([ version [, domain/namespace [, id/data ]]]):
// to generate a UUID as a string
// - uuid_str(u):
// to convert a UUID into a well-formed UUID string
// - uuid_blob(u):
// to convert a UUID into a 16-byte blob
// - uuid_extract_version(u):
// to extract the version of an RFC 4122 UUID
// - uuid_extract_timestamp(u):
// to extract the timestamp of a version 1/2/6/7 UUID
// - gen_random_uuid():
// to generate a version 4 (random) UUID
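//
// A minimal usage sketch, assuming db is a [sqlite3.Conn] opened elsewhere:
//
//	if err := Register(db); err != nil {
//		// handle the registration error
//	}
//	stmt, _, err := db.Prepare(`SELECT uuid(), uuid(7),
//		uuid(5, 'dns', 'example.com'), uuid_extract_version(uuid(7))`)
//	if err == nil {
//		if stmt.Step() {
//			// a random v4, a time-ordered v7, a name-based v5, and 7
//		}
//		stmt.Close()
//	}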
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("uuid", 0, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 1, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate),
db.CreateFunction("uuid_str", 1, flags, toString),
db.CreateFunction("uuid_blob", 1, flags, toBlob),
db.CreateFunction("uuid_extract_version", 1, flags, version),
db.CreateFunction("uuid_extract_timestamp", 1, flags, timestamp),
db.CreateFunction("gen_random_uuid", 0, sqlite3.INNOCUOUS, generate))
}
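// registerAndUse is an illustrative sketch (not part of the original source)
// showing the functions after registration; it assumes Conn.Exec behaves as
// elsewhere in this module.
func registerAndUse(db *sqlite3.Conn) error {
	if err := Register(db); err != nil {
		return err
	}
	// uuid()                        -- version 4 (random) UUID
	// uuid(7)                       -- version 7 UUID
	// uuid(5, 'dns', 'example.com') -- name-based SHA-1 UUID
	// uuid_extract_version(uuid(1)) -- 1
	return db.Exec(`SELECT uuid(7), uuid(5, 'dns', 'example.com'), gen_random_uuid()`)
}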
func generate(ctx sqlite3.Context, arg ...sqlite3.Value) {
var (
ver int
err error
u uuid.UUID
)
if len(arg) > 0 {
ver = arg[0].Int()
} else {
ver = 4
}
switch ver {
case 1:
u, err = uuid.NewUUID()
case 4:
u, err = uuid.NewRandom()
case 6:
u, err = uuid.NewV6()
case 7:
u, err = uuid.NewV7()
case 2:
var domain uuid.Domain
if len(arg) > 1 {
domain = uuid.Domain(arg[1].Int64())
if domain == 0 {
if txt := arg[1].RawText(); len(txt) > 0 {
switch txt[0] | 0x20 { // to lower
case 'g': // group
domain = uuid.Group
case 'o': // org
domain = uuid.Org
}
}
}
}
switch {
case len(arg) > 2:
u, err = uuid.NewDCESecurity(domain, uint32(arg[2].Int64()))
case domain == uuid.Person:
u, err = uuid.NewDCEPerson()
case domain == uuid.Group:
u, err = uuid.NewDCEGroup()
default:
err = util.ErrorString("missing id")
}
case 3, 5:
if len(arg) < 2 {
err = util.ErrorString("missing data")
break
}
ns, err := fromValue(arg[1])
if err != nil {
space := arg[1].RawText()
switch {
case bytes.EqualFold(space, []byte("url")):
ns = uuid.NameSpaceURL
case bytes.EqualFold(space, []byte("oid")):
ns = uuid.NameSpaceOID
case bytes.EqualFold(space, []byte("dns")):
ns = uuid.NameSpaceDNS
case bytes.EqualFold(space, []byte("fqdn")):
ns = uuid.NameSpaceDNS
case bytes.EqualFold(space, []byte("x500")):
ns = uuid.NameSpaceX500
default:
ctx.ResultError(err)
return // notest
}
}
if ver == 3 {
u = uuid.NewMD5(ns, arg[2].RawBlob())
} else {
u = uuid.NewSHA1(ns, arg[2].RawBlob())
}
default:
err = fmt.Errorf("invalid version: %d", ver)
}
if err != nil {
ctx.ResultError(fmt.Errorf("uuid: %w", err)) // notest
} else {
ctx.ResultText(u.String())
}
}
func fromValue(arg sqlite3.Value) (u uuid.UUID, err error) {
switch t := arg.Type(); t {
case sqlite3.TEXT:
u, err = uuid.ParseBytes(arg.RawText())
if err != nil {
err = fmt.Errorf("uuid: %w", err)
}
case sqlite3.BLOB:
blob := arg.RawBlob()
if len := len(blob); len != 16 {
err = fmt.Errorf("uuid: invalid BLOB length: %d", len)
} else {
copy(u[:], blob)
}
default:
err = fmt.Errorf("uuid: invalid type: %v", t)
}
return u, err
}
func toBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err) // notest
} else {
ctx.ResultBlob(u[:])
}
}
func toString(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err) // notest
} else {
ctx.ResultText(u.String())
}
}
func version(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
return // notest
}
if u.Variant() == uuid.RFC4122 {
ctx.ResultInt64(int64(u.Version()))
}
}
func timestamp(ctx sqlite3.Context, arg ...sqlite3.Value) {
u, err := fromValue(arg[0])
if err != nil {
ctx.ResultError(err)
return // notest
}
if u.Variant() == uuid.RFC4122 {
switch u.Version() {
case 1, 2, 6, 7:
ctx.ResultTime(
time.Unix(u.Time().UnixTime()),
sqlite3.TimeFormatDefault)
}
}
}
// Package zorder provides functions for z-order transformations.
//
// https://sqlite.org/src/doc/tip/ext/misc/zorder.c
package zorder
import (
"errors"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Register registers the zorder and unzorder SQL functions.
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
return errors.Join(
db.CreateFunction("zorder", -1, flags, zorder),
db.CreateFunction("unzorder", 3, flags, unzorder))
}
func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) {
var x [63]int64
if len(arg) > len(x) {
ctx.ResultError(util.ErrorString("zorder: too many parameters"))
return
}
for i := range arg {
x[i] = arg[i].Int64()
}
var z int64
if len(arg) > 0 {
for i := range x {
j := i % len(arg)
z |= (x[j] & 1) << i
x[j] >>= 1
}
}
for i := range arg {
if x[i] != 0 {
ctx.ResultError(util.ErrorString("zorder: parameter too large"))
return
}
}
ctx.ResultInt64(z)
}
func unzorder(ctx sqlite3.Context, arg ...sqlite3.Value) {
i := arg[2].Int64()
n := arg[1].Int64()
z := arg[0].Int64()
var k int
var x int64
for j := i; j < 63; j += n {
x |= ((z >> j) & 1) << k
k++
}
ctx.ResultInt64(x)
}
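// zorderExample is an illustrative sketch (not part of the original source)
// of the bit interleaving performed by zorder/unzorder; it assumes Conn.Exec
// behaves as elsewhere in this module.
func zorderExample(db *sqlite3.Conn) error {
	// zorder(5, 3) interleaves 0b101 and 0b011 into 0b011011 = 27;
	// unzorder(27, 2, 0) = 5 and unzorder(27, 2, 1) = 3 recover the inputs.
	return db.Exec(`SELECT zorder(5, 3), unzorder(27, 2, 0), unzorder(27, 2, 1)`)
}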
package sqlite3
import (
"context"
"io"
"iter"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
)
// CollationNeeded registers a callback to be invoked
// whenever an unknown collation sequence is required.
//
// https://sqlite.org/c3ref/collation_needed.html
func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
var enable int32
if cb != nil {
enable = 1
}
rc := res_t(c.call("sqlite3_collation_needed_go", stk_t(c.handle), stk_t(enable)))
if err := c.error(rc); err != nil {
return err
}
c.collation = cb
return nil
}
// AnyCollationNeeded uses [Conn.CollationNeeded] to register
// a fake collating function for any unknown collating sequence.
// The fake collating function works like BINARY.
//
// This can be used to load schemas that contain
// one or more unknown collating sequences.
func (c *Conn) AnyCollationNeeded() error {
rc := res_t(c.call("sqlite3_anycollseq_init", stk_t(c.handle), 0, 0))
if err := c.error(rc); err != nil {
return err
}
c.collation = nil
return nil
}
// CreateCollation defines a new collating sequence.
//
// https://sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn CollatingFunction) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
rc := res_t(c.call("sqlite3_create_collation_go",
stk_t(c.handle), stk_t(namePtr), stk_t(funcPtr)))
return c.error(rc)
}
// CollatingFunction is the type of a collation callback.
// Implementations must not retain a or b.
type CollatingFunction func(a, b []byte) int
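// createCollationExample is an illustrative sketch (not part of the original
// source): it registers a collation that compares text after folding bytes
// with |0x20, which lower-cases ASCII letters (a rough, ASCII-only fold).
// The callback must behave like a three-way compare, returning a negative,
// zero or positive result.
func createCollationExample(c *Conn) error {
	return c.CreateCollation("nocase_ascii", func(a, b []byte) int {
		for i := 0; i < len(a) && i < len(b); i++ {
			ca, cb := a[i]|0x20, b[i]|0x20 // fold ASCII letters to lower case
			if ca != cb {
				return int(ca) - int(cb)
			}
		}
		return len(a) - len(b)
	})
	// SELECT name FROM t ORDER BY name COLLATE nocase_ascii;
}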
// CreateFunction defines a new scalar SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
rc := res_t(c.call("sqlite3_create_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// ScalarFunction is the type of a scalar SQL function.
// Implementations must not retain arg.
type ScalarFunction func(ctx Context, arg ...Value)
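// createFunctionExample is an illustrative sketch (not part of the original
// source): it registers reverse(X), a deterministic scalar function that
// reverses the bytes of its text argument, so SELECT reverse('abc') is 'cba'.
func createFunctionExample(c *Conn) error {
	return c.CreateFunction("reverse", 1, DETERMINISTIC|INNOCUOUS,
		func(ctx Context, arg ...Value) {
			b := []byte(arg[0].Text())
			for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
				b[i], b[j] = b[j], b[i]
			}
			ctx.ResultText(string(b))
		})
}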
// CreateAggregateFunction defines a new aggregate SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
var a aggregateFunc
coro := func(yieldCoro func(struct{}) bool) {
seq := func(yieldSeq func([]Value) bool) {
for yieldSeq(a.arg) {
if !yieldCoro(struct{}{}) {
break
}
}
}
fn(&a.ctx, seq)
}
a.next, a.stop = iter.Pull(coro)
return &a
}))
}
rc := res_t(c.call("sqlite3_create_aggregate_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// AggregateSeqFunction is the type of an aggregate SQL function.
// Implementations must not retain the slices yielded by seq.
type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value])
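// createAggregateExample is an illustrative sketch (not part of the original
// source): it registers my_sum, an aggregate that sums its integer argument.
// Each slice yielded by seq holds the arguments of one row; the result is
// set on ctx once the sequence has been consumed.
func createAggregateExample(c *Conn) error {
	return c.CreateAggregateFunction("my_sum", 1, DETERMINISTIC|INNOCUOUS,
		func(ctx *Context, seq iter.Seq[[]Value]) {
			var sum int64
			for arg := range seq {
				sum += arg[0].Int64()
			}
			ctx.ResultInt64(sum)
		})
}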
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
agg := fn()
if win, ok := agg.(WindowFunction); ok {
return win
}
return windowFunc{agg, name}
}))
}
rc := res_t(c.call("sqlite3_create_window_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// AggregateConstructor is an [AggregateFunction] constructor.
type AggregateConstructor func() AggregateFunction
// AggregateFunction is the interface an aggregate function should implement.
//
// https://sqlite.org/appfunc.html
type AggregateFunction interface {
// Step is invoked to add a row to the current window.
// The function arguments, if any, corresponding to the row being added, are passed to Step.
// Implementations must not retain arg.
Step(ctx Context, arg ...Value)
// Value is invoked to return the current (or final) value of the aggregate.
Value(ctx Context)
}
// WindowFunction is the interface an aggregate window function should implement.
//
// https://sqlite.org/windowfunctions.html
type WindowFunction interface {
AggregateFunction
// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
// The function arguments, if any, are those passed to Step for the row being removed.
// Implementations must not retain arg.
Inverse(ctx Context, arg ...Value)
}
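// windowSum is an illustrative sketch (not part of the original source) of a
// WindowFunction: a running integer sum that also supports removing rows from
// the window. Register it with, e.g.:
//
//	c.CreateWindowFunction("win_sum", 1, DETERMINISTIC|INNOCUOUS,
//		func() AggregateFunction { return &windowSum{} })
type windowSum struct{ sum int64 }
func (w *windowSum) Step(ctx Context, arg ...Value)    { w.sum += arg[0].Int64() }
func (w *windowSum) Inverse(ctx Context, arg ...Value) { w.sum -= arg[0].Int64() }
func (w *windowSum) Value(ctx Context)                 { ctx.ResultInt64(w.sum) }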
// OverloadFunction overloads a function for a virtual table.
//
// https://sqlite.org/c3ref/overload_function.html
func (c *Conn) OverloadFunction(name string, nArg int) error {
defer c.arena.mark()()
namePtr := c.arena.string(name)
rc := res_t(c.call("sqlite3_overload_function",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg)))
return c.error(rc)
}
func destroyCallback(ctx context.Context, mod api.Module, pApp ptr_t) {
util.DelHandle(ctx, pApp)
}
func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTextRep uint32, zName ptr_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil {
name := util.ReadString(mod, zName, _MAX_NAME)
c.collation(c, name)
}
}
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
fn := util.GetHandle(ctx, pApp).(CollatingFunction)
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
}
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) {
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn := util.GetHandle(db.ctx, pApp).(ScalarFunction)
fn(Context{db, pCtx}, *args...)
}
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) {
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn, _ := callbackAggregate(db, pAgg, pApp)
fn.Step(Context{db, pCtx}, *args...)
}
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) {
db := ctx.Value(connKey{}).(*Conn)
fn, handle := callbackAggregate(db, pAgg, pApp)
fn.Value(Context{db, pCtx})
// Cleanup.
if final != 0 {
var err error
if handle != 0 {
err = util.DelHandle(ctx, handle)
} else if c, ok := fn.(io.Closer); ok {
err = c.Close()
}
if err != nil {
Context{db, pCtx}.ResultError(err)
return // notest
}
}
}
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) {
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn := util.GetHandle(db.ctx, pAgg).(WindowFunction)
fn.Inverse(Context{db, pCtx}, *args...)
}
func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
if pApp == 0 {
handle := util.Read32[ptr_t](db.mod, pAgg)
return util.GetHandle(db.ctx, handle).(AggregateFunction), handle
}
// We need to create the aggregate.
fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)()
if pAgg != 0 {
handle := util.AddHandle(db.ctx, fn)
util.Write32(db.mod, pAgg, handle)
return fn, handle
}
return fn, 0
}
var (
valueArgsPool sync.Pool
valueArgsLen atomic.Int32
)
func callbackArgs(db *Conn, nArg int32, pArg ptr_t) *[]Value {
arg, ok := valueArgsPool.Get().(*[]Value)
if !ok || cap(*arg) < int(nArg) {
max := valueArgsLen.Or(nArg) | nArg
lst := make([]Value, max)
arg = &lst
}
lst := (*arg)[:nArg]
for i := range lst {
lst[i] = Value{
c: db,
handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen),
}
}
*arg = lst
return arg
}
func returnArgs(p *[]Value) {
valueArgsPool.Put(p)
}
type aggregateFunc struct {
next func() (struct{}, bool)
stop func()
ctx Context
arg []Value
}
func (a *aggregateFunc) Step(ctx Context, arg ...Value) {
a.ctx = ctx
a.arg = append(a.arg[:0], arg...)
if _, more := a.next(); !more {
a.stop()
}
}
func (a *aggregateFunc) Value(ctx Context) {
a.ctx = ctx
a.stop()
}
func (a *aggregateFunc) Close() error {
a.stop()
return nil
}
type windowFunc struct {
AggregateFunction
name string
}
func (w windowFunc) Inverse(ctx Context, arg ...Value) {
// Implementing inverse allows certain queries that don't really need it to succeed.
ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function"))
}
//go:build unix
package alloc
import (
"math"
"github.com/tetratelabs/wazero/experimental"
"golang.org/x/sys/unix"
)
func NewMemory(_, max uint64) experimental.LinearMemory {
// Round up to the page size.
rnd := uint64(unix.Getpagesize() - 1)
max = (max + rnd) &^ rnd
if max > math.MaxInt {
// This ensures int(max) overflows to a negative value,
// and unix.Mmap returns EINVAL.
max = math.MaxUint64
}
// Reserve max bytes of address space, to ensure we won't need to move it.
// A protected, private, anonymous mapping should not commit memory.
b, err := unix.Mmap(-1, 0, int(max), unix.PROT_NONE, unix.MAP_PRIVATE|unix.MAP_ANON)
if err != nil {
panic(err)
}
return &mmappedMemory{buf: b[:0]}
}
// The slice covers the entire mmapped memory:
// - len(buf) is the already committed memory,
// - cap(buf) is the reserved address space.
type mmappedMemory struct {
buf []byte
}
func (m *mmappedMemory) Reallocate(size uint64) []byte {
com := uint64(len(m.buf))
res := uint64(cap(m.buf))
if com < size && size <= res {
// Round up to the page size.
rnd := uint64(unix.Getpagesize() - 1)
new := (size + rnd) &^ rnd
// Commit additional memory up to new bytes.
err := unix.Mprotect(m.buf[com:new], unix.PROT_READ|unix.PROT_WRITE)
if err != nil {
return nil
}
// Update committed memory.
m.buf = m.buf[:new]
}
// Limit returned capacity because bytes beyond
// len(m.buf) have not yet been committed.
return m.buf[:size:len(m.buf)]
}
func (m *mmappedMemory) Free() {
err := unix.Munmap(m.buf[:cap(m.buf)])
if err != nil {
panic(err)
}
m.buf = nil
}
package testcfg
import (
"os"
"path/filepath"
"github.com/tetratelabs/wazero"
"github.com/ncruces/go-sqlite3"
)
// notest
func init() {
sqlite3.RuntimeConfig = wazero.NewRuntimeConfig().WithMemoryLimitPages(512)
if os.Getenv("CI") != "" {
path := filepath.Join(os.TempDir(), "wazero")
if err := os.MkdirAll(path, 0777); err == nil {
if cache, err := wazero.NewCompilationCacheWithDir(path); err == nil {
sqlite3.RuntimeConfig = sqlite3.RuntimeConfig.
WithCompilationCache(cache)
}
}
}
}
package util
import (
"runtime"
"strconv"
)
type ErrorString string
func (e ErrorString) Error() string { return string(e) }
const (
NilErr = ErrorString("sqlite3: invalid memory address or null pointer dereference")
OOMErr = ErrorString("sqlite3: out of memory")
RangeErr = ErrorString("sqlite3: index out of range")
NoNulErr = ErrorString("sqlite3: missing NUL terminator")
NoBinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded")
BadBinaryErr = ErrorString("sqlite3: invalid SQLite binary embed/set/loaded")
TimeErr = ErrorString("sqlite3: invalid time value")
WhenceErr = ErrorString("sqlite3: invalid whence")
OffsetErr = ErrorString("sqlite3: invalid offset")
TailErr = ErrorString("sqlite3: multiple statements")
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
ValueErr = ErrorString("sqlite3: unsupported value")
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
)
func AssertErr() ErrorString {
msg := "sqlite3: assertion failed"
if _, file, line, ok := runtime.Caller(1); ok {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return ErrorString(msg)
}
func ErrorCodeString(rc uint32) string {
switch rc {
case ABORT_ROLLBACK:
return "sqlite3: abort due to ROLLBACK"
case ROW:
return "sqlite3: another row available"
case DONE:
return "sqlite3: no more rows available"
}
switch rc & 0xff {
case OK:
return "sqlite3: not an error"
case ERROR:
return "sqlite3: SQL logic error"
case INTERNAL:
break
case PERM:
return "sqlite3: access permission denied"
case ABORT:
return "sqlite3: query aborted"
case BUSY:
return "sqlite3: database is locked"
case LOCKED:
return "sqlite3: database table is locked"
case NOMEM:
return "sqlite3: out of memory"
case READONLY:
return "sqlite3: attempt to write a readonly database"
case INTERRUPT:
return "sqlite3: interrupted"
case IOERR:
return "sqlite3: disk I/O error"
case CORRUPT:
return "sqlite3: database disk image is malformed"
case NOTFOUND:
return "sqlite3: unknown operation"
case FULL:
return "sqlite3: database or disk is full"
case CANTOPEN:
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case EMPTY:
break
case SCHEMA:
return "sqlite3: database schema has changed"
case TOOBIG:
return "sqlite3: string or blob too big"
case CONSTRAINT:
return "sqlite3: constraint failed"
case MISMATCH:
return "sqlite3: datatype mismatch"
case MISUSE:
return "sqlite3: bad parameter or other API misuse"
case NOLFS:
break
case AUTH:
return "sqlite3: authorization denied"
case FORMAT:
break
case RANGE:
return "sqlite3: column index out of range"
case NOTADB:
return "sqlite3: file is not a database"
case NOTICE:
return "sqlite3: notification message"
case WARNING:
return "sqlite3: warning message"
}
return "sqlite3: unknown error"
}
type ErrorJoiner []error
func (j *ErrorJoiner) Join(errs ...error) {
for _, err := range errs {
if err != nil {
*j = append(*j, err)
}
}
}
package util
import (
"context"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
type funcVI[T0 i32] func(context.Context, api.Module, T0)
func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
fn(ctx, mod, T0(stack[0]))
}
func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVI[T0](fn),
[]api.ValueType{api.ValueTypeI32}, nil).
Export(name)
}
type funcVII[T0, T1 i32] func(context.Context, api.Module, T0, T1)
func (fn funcVII[T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[1] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]))
}
func ExportFuncVII[T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVII[T0, T1](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, nil).
Export(name)
}
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[2] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))
}
func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVIII[T0, T1, T2](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
Export(name)
}
type funcVIIII[T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3)
func (fn funcVIIII[T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[3] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))
}
func ExportFuncVIIII[T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVIIII[T0, T1, T2, T3](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
Export(name)
}
type funcVIIIII[T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4)
func (fn funcVIIIII[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[4] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))
}
func ExportFuncVIIIII[T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVIIIII[T0, T1, T2, T3, T4](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
Export(name)
}
type funcVIIIIJ[T0, T1, T2, T3 i32, T4 i64] func(context.Context, api.Module, T0, T1, T2, T3, T4)
func (fn funcVIIIIJ[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[4] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))
}
func ExportFuncVIIIIJ[T0, T1, T2, T3 i32, T4 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4)) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcVIIIIJ[T0, T1, T2, T3, T4](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, nil).
Export(name)
}
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0])))
}
func ExportFuncII[TR, T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcII[TR, T0](fn),
[]api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIII[TR, T0, T1 i32] func(context.Context, api.Module, T0, T1) TR
func (fn funcIII[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[1] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
}
func ExportFuncIII[TR, T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIII[TR, T0, T1](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIII[TR, T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) TR
func (fn funcIIII[TR, T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[2] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
}
func ExportFuncIIII[TR, T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIII[TR, T0, T1, T2](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIII[TR, T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIII[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[3] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
}
func ExportFuncIIIII[TR, T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIII[TR, T0, T1, T2, T3](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIIII[TR, T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) TR
func (fn funcIIIIII[TR, T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[4] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
}
func ExportFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIIII[TR, T0, T1, T2, T3, T4](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR
func (fn funcIIIIIII[TR, T0, T1, T2, T3, T4, T5]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[5] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]), T5(stack[5])))
}
func ExportFuncIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIIIII[TR, T0, T1, T2, T3, T4, T5](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[3] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
}
func ExportFuncIIIIJ[TR, T0, T1, T2 i32, T3 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIIIJ[TR, T0, T1, T2, T3](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
type funcIIJ[TR, T0 i32, T1 i64] func(context.Context, api.Module, T0, T1) TR
func (fn funcIIJ[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[1] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
}
func ExportFuncIIJ[TR, T0 i32, T1 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(funcIIJ[TR, T0, T1](fn),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
package util
import (
"context"
"io"
)
type handleState struct {
handles []any
holes int
}
func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) {
for _, h := range s.handles {
if c, ok := h.(io.Closer); ok {
c.Close()
}
}
s.handles = nil
s.holes = 0
}
func GetHandle(ctx context.Context, id Ptr_t) any {
if id == 0 {
return nil
}
s := ctx.Value(moduleKey{}).(*moduleState)
return s.handles[^id]
}
func DelHandle(ctx context.Context, id Ptr_t) error {
if id == 0 {
return nil
}
s := ctx.Value(moduleKey{}).(*moduleState)
a := s.handles[^id]
s.handles[^id] = nil
if l := Ptr_t(len(s.handles)); l == -id {
s.handles = s.handles[:l-1]
} else {
s.holes++
}
if c, ok := a.(io.Closer); ok {
return c.Close()
}
return nil
}
func AddHandle(ctx context.Context, a any) Ptr_t {
if a == nil {
panic(NilErr)
}
s := ctx.Value(moduleKey{}).(*moduleState)
// Find an empty slot.
if s.holes > cap(s.handles)-len(s.handles) {
for id, h := range s.handles {
if h == nil {
s.holes--
s.handles[id] = a
return ^Ptr_t(id)
}
}
}
// Add a new slot.
s.handles = append(s.handles, a)
return -Ptr_t(len(s.handles))
}
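// Handle encoding (explanatory note, not part of the original source):
// a handle stored at index i of s.handles is identified by id = -(i+1)
// (equivalently ^Ptr_t(i)), so a valid id is never 0 and ^id recovers the
// index. The first AddHandle therefore returns -1, the second -2, and so on.
// Deleting the handle with id == -len(s.handles) (the last slot) truncates
// the slice; deleting any other handle leaves a nil hole that a later
// AddHandle may reuse once holes outnumber the slice's spare capacity.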
package util
import (
"encoding/json"
"math"
"strconv"
"time"
"unsafe"
)
type JSON struct{ Value any }
func (j JSON) Scan(value any) error {
var buf []byte
switch v := value.(type) {
case []byte:
buf = v
case string:
buf = unsafe.Slice(unsafe.StringData(v), len(v))
case int64:
buf = strconv.AppendInt(nil, v, 10)
case float64:
buf = AppendNumber(nil, v)
case time.Time:
buf = append(buf, '"')
buf = v.AppendFormat(buf, time.RFC3339Nano)
buf = append(buf, '"')
case nil:
buf = []byte("null")
default:
panic(AssertErr())
}
return json.Unmarshal(buf, j.Value)
}
func AppendNumber(dst []byte, f float64) []byte {
switch {
case math.IsNaN(f):
dst = append(dst, "null"...)
case math.IsInf(f, 1):
dst = append(dst, "9.0e999"...)
case math.IsInf(f, -1):
dst = append(dst, "-9.0e999"...)
default:
return strconv.AppendFloat(dst, f, 'g', -1, 64)
}
return dst
}
package util
import "math"
func abs(n int) int {
if n < 0 {
return -n
}
return n
}
func GCD(m, n int) int {
for n != 0 {
m, n = n, m%n
}
return abs(m)
}
func LCM(m, n int) int {
if n == 0 {
return 0
}
return abs(n) * (abs(m) / GCD(m, n))
}
// https://developer.nvidia.com/blog/lerp-faster-cuda/
func Lerp(v0, v1, t float64) float64 {
return math.FMA(t, v1, math.FMA(-t, v0, v0))
}
package util
import (
"bytes"
"math"
"github.com/tetratelabs/wazero/api"
)
const (
PtrLen = 4
IntLen = 4
)
type (
i8 interface{ ~int8 | ~uint8 }
i32 interface{ ~int32 | ~uint32 }
i64 interface{ ~int64 | ~uint64 }
Stk_t = uint64
Ptr_t uint32
Res_t int32
)
func View(mod api.Module, ptr Ptr_t, size int64) []byte {
if ptr == 0 {
panic(NilErr)
}
if uint64(size) > math.MaxUint32 {
panic(RangeErr)
}
buf, ok := mod.Memory().Read(uint32(ptr), uint32(size))
if !ok {
panic(RangeErr)
}
return buf
}
func Read[T i8](mod api.Module, ptr Ptr_t) T {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadByte(uint32(ptr))
if !ok {
panic(RangeErr)
}
return T(v)
}
func Write[T i8](mod api.Module, ptr Ptr_t, v T) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteByte(uint32(ptr), uint8(v))
if !ok {
panic(RangeErr)
}
}
func Read32[T i32](mod api.Module, ptr Ptr_t) T {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint32Le(uint32(ptr))
if !ok {
panic(RangeErr)
}
return T(v)
}
func Write32[T i32](mod api.Module, ptr Ptr_t, v T) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint32Le(uint32(ptr), uint32(v))
if !ok {
panic(RangeErr)
}
}
func Read64[T i64](mod api.Module, ptr Ptr_t) T {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint64Le(uint32(ptr))
if !ok {
panic(RangeErr)
}
return T(v)
}
func Write64[T i64](mod api.Module, ptr Ptr_t, v T) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint64Le(uint32(ptr), uint64(v))
if !ok {
panic(RangeErr)
}
}
func ReadFloat64(mod api.Module, ptr Ptr_t) float64 {
return math.Float64frombits(Read64[uint64](mod, ptr))
}
func WriteFloat64(mod api.Module, ptr Ptr_t, v float64) {
Write64(mod, ptr, math.Float64bits(v))
}
func ReadBool(mod api.Module, ptr Ptr_t) bool {
return Read32[int32](mod, ptr) != 0
}
func WriteBool(mod api.Module, ptr Ptr_t, v bool) {
var i int32
if v {
i = 1
}
Write32(mod, ptr, i)
}
func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string {
if ptr == 0 {
panic(NilErr)
}
if maxlen <= 0 {
return ""
}
mem := mod.Memory()
maxlen = min(maxlen, math.MaxInt32-1) + 1
buf, ok := mem.Read(uint32(ptr), uint32(maxlen))
if !ok {
buf, ok = mem.Read(uint32(ptr), mem.Size()-uint32(ptr))
if !ok {
panic(RangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i >= 0 {
return string(buf[:i])
}
panic(NoNulErr)
}
func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) {
buf := View(mod, ptr, int64(len(b)))
copy(buf, b)
}
func WriteString(mod api.Module, ptr Ptr_t, s string) {
buf := View(mod, ptr, int64(len(s))+1)
buf[len(s)] = 0
copy(buf, s)
}
//go:build unix
package util
import (
"context"
"os"
"unsafe"
"github.com/tetratelabs/wazero/api"
"golang.org/x/sys/unix"
)
type mmapState struct {
regions []*MappedRegion
}
func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *MappedRegion {
// Find unused region.
for _, r := range s.regions {
if !r.used && r.size == size {
return r
}
}
// Allocate page-aligned memory.
alloc := mod.ExportedFunction("aligned_alloc")
stack := [...]Stk_t{
Stk_t(unix.Getpagesize()),
Stk_t(size),
}
if err := alloc.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
if stack[0] == 0 {
panic(OOMErr)
}
// Save the newly allocated region.
ptr := Ptr_t(stack[0])
buf := View(mod, ptr, int64(size))
ret := &MappedRegion{
Ptr: ptr,
size: size,
addr: unsafe.Pointer(&buf[0]),
}
s.regions = append(s.regions, ret)
return ret
}
type MappedRegion struct {
addr unsafe.Pointer
Ptr Ptr_t
size int32
used bool
}
func MapRegion(ctx context.Context, mod api.Module, f *os.File, offset int64, size int32, readOnly bool) (*MappedRegion, error) {
s := ctx.Value(moduleKey{}).(*moduleState)
r := s.new(ctx, mod, size)
err := r.mmap(f, offset, readOnly)
if err != nil {
return nil, err
}
return r, nil
}
func (r *MappedRegion) Unmap() error {
// We can't munmap the region, otherwise it could be remapped.
// Instead, convert it to a protected, private, anonymous mapping.
// If successful, it can be reused for a subsequent mmap.
_, err := unix.MmapPtr(-1, 0, r.addr, uintptr(r.size),
unix.PROT_NONE, unix.MAP_PRIVATE|unix.MAP_FIXED|unix.MAP_ANON)
r.used = err != nil
return err
}
func (r *MappedRegion) mmap(f *os.File, offset int64, readOnly bool) error {
prot := unix.PROT_READ
if !readOnly {
prot |= unix.PROT_WRITE
}
_, err := unix.MmapPtr(int(f.Fd()), offset, r.addr, uintptr(r.size),
prot, unix.MAP_SHARED|unix.MAP_FIXED)
r.used = err == nil
return err
}
package util
import (
"context"
"github.com/tetratelabs/wazero/experimental"
"github.com/ncruces/go-sqlite3/internal/alloc"
)
type ConnKey struct{}
type moduleKey struct{}
type moduleState struct {
mmapState
handleState
}
func NewContext(ctx context.Context) context.Context {
state := new(moduleState)
ctx = experimental.WithMemoryAllocator(ctx, experimental.MemoryAllocatorFunc(alloc.NewMemory))
ctx = experimental.WithCloseNotifier(ctx, state)
ctx = context.WithValue(ctx, moduleKey{}, state)
return ctx
}
package util
type Pointer[T any] struct{ Value T }
func (p Pointer[T]) unwrap() any { return p.Value }
type PointerUnwrap interface{ unwrap() any }
func UnwrapPointer(p PointerUnwrap) any {
return p.unwrap()
}
package util
import "reflect"
func ReflectType(v reflect.Value) reflect.Type {
if v.Kind() != reflect.Invalid {
return v.Type()
}
return nil
}
package util
type Set[E comparable] map[E]struct{}
func (s Set[E]) Add(v E) {
s[v] = struct{}{}
}
func (s Set[E]) Contains(v E) bool {
_, ok := s[v]
return ok
}
package sqlite3
import "github.com/ncruces/go-sqlite3/internal/util"
// JSON returns a value that can be used as an argument to
// [database/sql.DB.Exec], [database/sql.Row.Scan] and similar methods to
// store value as JSON, or decode JSON into value.
// JSON should NOT be used with [Stmt.BindJSON], [Stmt.ColumnJSON],
// [Value.JSON], or [Context.ResultJSON].
func JSON(value any) any {
return util.JSON{Value: value}
}
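// Example (an illustrative sketch, not part of the original source): with
// database/sql and this module's driver (assumed registered as "sqlite3"),
// JSON stores a Go value as JSON text on Exec and decodes it on Scan:
//
//	_, err := db.Exec(`INSERT INTO t (doc) VALUES (?)`, sqlite3.JSON(map[string]any{"a": 1}))
//	var doc map[string]any
//	err = db.QueryRow(`SELECT doc FROM t`).Scan(sqlite3.JSON(&doc))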
package sqlite3
import "github.com/ncruces/go-sqlite3/internal/util"
// Pointer returns a pointer to a value that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
// Pointer should NOT be used with [Stmt.BindPointer],
// [Value.Pointer], or [Context.ResultPointer].
//
// https://sqlite.org/bindptr.html
func Pointer[T any](value T) any {
return util.Pointer[T]{Value: value}
}
package sqlite3
import (
"bytes"
"math"
"reflect"
"strconv"
"strings"
"time"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Quote escapes and quotes a value,
// making it safe to embed in SQL text.
// Strings with embedded NUL characters are truncated.
//
// https://sqlite.org/lang_corefunc.html#quote
func Quote(value any) string {
switch v := value.(type) {
case nil:
return "NULL"
case bool:
if v {
return "1"
} else {
return "0"
}
case int:
return strconv.Itoa(v)
case int64:
return strconv.FormatInt(v, 10)
case float64:
switch {
case math.IsNaN(v):
return "NULL"
case math.IsInf(v, 1):
return "9.0e999"
case math.IsInf(v, -1):
return "-9.0e999"
}
return strconv.FormatFloat(v, 'g', -1, 64)
case time.Time:
return "'" + v.Format(time.RFC3339Nano) + "'"
case string:
if i := strings.IndexByte(v, 0); i >= 0 {
v = v[:i]
}
buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
buf[0] = '\''
i := 1
for _, b := range []byte(v) {
if b == '\'' {
buf[i] = b
i += 1
}
buf[i] = b
i += 1
}
buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf))
case []byte:
buf := make([]byte, 3+2*len(v))
buf[1] = '\''
buf[0] = 'x'
i := 2
for _, b := range v {
const hex = "0123456789ABCDEF"
buf[i+0] = hex[b/16]
buf[i+1] = hex[b%16]
i += 2
}
buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf))
case ZeroBlob:
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
buf[1] = '\''
buf[0] = 'x'
buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf))
}
v := reflect.ValueOf(value)
k := v.Kind()
if k == reflect.Interface || k == reflect.Pointer {
if v.IsNil() {
return "NULL"
}
v = v.Elem()
k = v.Kind()
}
switch {
case v.CanInt():
return strconv.FormatInt(v.Int(), 10)
case v.CanUint():
return strconv.FormatUint(v.Uint(), 10)
case v.CanFloat():
return Quote(v.Float())
case k == reflect.Bool:
return Quote(v.Bool())
case k == reflect.String:
return Quote(v.String())
case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
v.Type().Elem().Kind() == reflect.Uint8:
return Quote(v.Bytes())
}
panic(util.ValueErr)
}
// QuoteIdentifier escapes and quotes an identifier,
// making it safe to embed in SQL text.
// Strings with embedded NUL characters panic.
func QuoteIdentifier(id string) string {
if strings.IndexByte(id, 0) >= 0 {
panic(util.ValueErr)
}
buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
buf[0] = '"'
i := 1
for _, b := range []byte(id) {
if b == '"' {
buf[i] = b
i += 1
}
buf[i] = b
i += 1
}
buf[len(buf)-1] = '"'
return unsafe.String(&buf[0], len(buf))
}
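// Worked examples (an illustrative sketch, not part of the original source):
// both functions escape embedded quotes by doubling them.
//
//	Quote("it's")                 // `'it''s'`
//	Quote([]byte{0xCA, 0xFE})     // `x'CAFE'`
//	Quote(math.Inf(1))            // `9.0e999`
//	QuoteIdentifier(`my "table"`) // `"my ""table"""`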
package sqlite3
import "sync"
var (
// +checklocks:extRegistryMtx
extRegistry []func(*Conn) error
extRegistryMtx sync.RWMutex
)
// AutoExtension causes the entryPoint function to be invoked
// for each new database connection that is created.
//
// https://sqlite.org/c3ref/auto_extension.html
func AutoExtension(entryPoint func(*Conn) error) {
extRegistryMtx.Lock()
defer extRegistryMtx.Unlock()
extRegistry = append(extRegistry, entryPoint)
}
func initExtensions(c *Conn) error {
extRegistryMtx.RLock()
defer extRegistryMtx.RUnlock()
for _, f := range extRegistry {
if err := f(c); err != nil {
return err
}
}
return nil
}
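// Example (an illustrative sketch, not part of the original source): register
// an extension entry point once, and every connection opened afterwards gets
// it automatically. Using the uuid extension from this module:
//
//	sqlite3.AutoExtension(uuid.Register)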
// Package sqlite3 wraps the C SQLite API.
package sqlite3
import (
"context"
"math/bits"
"os"
"sync"
"unsafe"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
)
// Configure SQLite Wasm.
//
// Importing package embed initializes [Binary]
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
var (
Binary []byte // Wasm binary to load.
Path string // Path to load the binary from.
RuntimeConfig wazero.RuntimeConfig
)
// Initialize decodes and compiles the SQLite Wasm binary.
// This is called implicitly when the first connection is opened,
// but is potentially slow, so you may want to call it at a more convenient time.
func Initialize() error {
instance.once.Do(compileSQLite)
return instance.err
}
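// Example (an illustrative sketch, not part of the original source): load the
// Wasm binary from a file (hypothetical path) and compile it eagerly at
// start-up, so the first connection does not pay the compilation cost:
//
//	sqlite3.Path = "/usr/local/lib/sqlite3.wasm"
//	if err := sqlite3.Initialize(); err != nil {
//		log.Fatal(err)
//	}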
var instance struct {
runtime wazero.Runtime
compiled wazero.CompiledModule
err error
once sync.Once
}
func compileSQLite() {
ctx := context.Background()
cfg := RuntimeConfig
if cfg == nil {
cfg = wazero.NewRuntimeConfig()
if bits.UintSize < 64 {
cfg = cfg.WithMemoryLimitPages(512) // 32MB
} else {
cfg = cfg.WithMemoryLimitPages(4096) // 256MB
}
cfg = cfg.WithCoreFeatures(api.CoreFeaturesV2)
}
instance.runtime = wazero.NewRuntimeWithConfig(ctx, cfg)
env := instance.runtime.NewHostModuleBuilder("env")
env = vfs.ExportHostFunctions(env)
env = exportCallbacks(env)
_, instance.err = env.Instantiate(ctx)
if instance.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
bin, instance.err = os.ReadFile(Path)
if instance.err != nil {
return
}
}
if bin == nil {
instance.err = util.NoBinaryErr
return
}
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
}
type sqlite struct {
ctx context.Context
mod api.Module
funcs struct {
fn [32]api.Function
id [32]*byte
mask uint32
}
stack [9]stk_t
}
func instantiateSQLite() (sqlt *sqlite, err error) {
if err := Initialize(); err != nil {
return nil, err
}
sqlt = new(sqlite)
sqlt.ctx = util.NewContext(context.Background())
sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
instance.compiled, wazero.NewModuleConfig().WithName(""))
if err != nil {
return nil, err
}
if sqlt.getfn("sqlite3_progress_handler_go") == nil {
return nil, util.BadBinaryErr
}
return sqlt, nil
}
func (sqlt *sqlite) close() error {
return sqlt.mod.Close(sqlt.ctx)
}
func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
if rc == _OK {
return nil
}
if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM {
panic(util.OOMErr)
}
if handle != 0 {
var msg, query string
if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
switch {
case msg == "not an error":
msg = ""
case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]:
msg = ""
}
}
if len(sql) != 0 {
if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 {
query = sql[0][i:]
}
}
if msg != "" || query != "" {
return &Error{code: rc, msg: msg, sql: query}
}
}
return xErrorCode(rc)
}
func (sqlt *sqlite) getfn(name string) api.Function {
c := &sqlt.funcs
p := unsafe.StringData(name)
for i := range c.id {
if c.id[i] == p {
c.id[i] = nil
c.mask &^= uint32(1) << i
return c.fn[i]
}
}
return sqlt.mod.ExportedFunction(name)
}
func (sqlt *sqlite) putfn(name string, fn api.Function) {
c := &sqlt.funcs
p := unsafe.StringData(name)
i := bits.TrailingZeros32(^c.mask)
if i < 32 {
c.id[i] = p
c.fn[i] = fn
c.mask |= uint32(1) << i
} else {
c.id[0] = p
c.fn[0] = fn
c.mask = uint32(1)
}
}
func (sqlt *sqlite) call(name string, params ...stk_t) stk_t {
copy(sqlt.stack[:], params)
fn := sqlt.getfn(name)
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
if err != nil {
panic(err)
}
sqlt.putfn(name, fn)
return stk_t(sqlt.stack[0])
}
func (sqlt *sqlite) free(ptr ptr_t) {
if ptr == 0 {
return
}
sqlt.call("sqlite3_free", stk_t(ptr))
}
func (sqlt *sqlite) new(size int64) ptr_t {
ptr := ptr_t(sqlt.call("sqlite3_malloc64", stk_t(size)))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t {
ptr = ptr_t(sqlt.call("sqlite3_realloc64", stk_t(ptr), stk_t(size)))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
return ptr
}
func (sqlt *sqlite) newBytes(b []byte) ptr_t {
if (*[0]byte)(b) == nil {
return 0
}
size := len(b)
if size == 0 {
size = 1
}
ptr := sqlt.new(int64(size))
util.WriteBytes(sqlt.mod, ptr, b)
return ptr
}
func (sqlt *sqlite) newString(s string) ptr_t {
ptr := sqlt.new(int64(len(s)) + 1)
util.WriteString(sqlt.mod, ptr, s)
return ptr
}
const arenaSize = 4096
func (sqlt *sqlite) newArena() arena {
return arena{
sqlt: sqlt,
base: sqlt.new(arenaSize),
}
}
type arena struct {
sqlt *sqlite
ptrs []ptr_t
base ptr_t
next int32
}
func (a *arena) free() {
if a.sqlt == nil {
return
}
for _, ptr := range a.ptrs {
a.sqlt.free(ptr)
}
a.sqlt.free(a.base)
a.sqlt = nil
}
func (a *arena) mark() (reset func()) {
ptrs := len(a.ptrs)
next := a.next
return func() {
rest := a.ptrs[ptrs:]
for _, ptr := range a.ptrs[:ptrs] {
a.sqlt.free(ptr)
}
a.ptrs = rest
a.next = next
}
}
func (a *arena) new(size int64) ptr_t {
// Align the next address, to 4 or 8 bytes.
if size&7 != 0 {
a.next = (a.next + 3) &^ 3
} else {
a.next = (a.next + 7) &^ 7
}
if size <= arenaSize-int64(a.next) {
ptr := a.base + ptr_t(a.next)
a.next += int32(size)
return ptr_t(ptr)
}
ptr := a.sqlt.new(size)
a.ptrs = append(a.ptrs, ptr)
return ptr_t(ptr)
}
func (a *arena) bytes(b []byte) ptr_t {
if (*[0]byte)(b) == nil {
return 0
}
ptr := a.new(int64(len(b)))
util.WriteBytes(a.sqlt.mod, ptr, b)
return ptr
}
func (a *arena) string(s string) ptr_t {
ptr := a.new(int64(len(s)) + 1)
util.WriteString(a.sqlt.mod, ptr, s)
return ptr
}
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress_handler", progressCallback)
util.ExportFuncIII(env, "go_busy_timeout", timeoutCallback)
util.ExportFuncIII(env, "go_busy_handler", busyCallback)
util.ExportFuncII(env, "go_commit_hook", commitCallback)
util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)
util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback)
util.ExportFuncIIIII(env, "go_wal_hook", walCallback)
util.ExportFuncIIIII(env, "go_trace", traceCallback)
util.ExportFuncIIIIII(env, "go_autovacuum_pages", autoVacuumCallback)
util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback)
util.ExportFuncVIII(env, "go_log", logCallback)
util.ExportFuncVI(env, "go_destroy", destroyCallback)
util.ExportFuncVIIII(env, "go_func", funcCallback)
util.ExportFuncVIIIII(env, "go_step", stepCallback)
util.ExportFuncVIIII(env, "go_value", valueCallback)
util.ExportFuncVIIII(env, "go_inverse", inverseCallback)
util.ExportFuncVIIII(env, "go_collation_needed", collationCallback)
util.ExportFuncIIIIII(env, "go_compare", compareCallback)
util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(xCreate))
util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(xConnect))
util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback)
util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback)
util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback)
util.ExportFuncIIIII(env, "go_vtab_update", vtabUpdateCallback)
util.ExportFuncIII(env, "go_vtab_rename", vtabRenameCallback)
util.ExportFuncIIIII(env, "go_vtab_find_function", vtabFindFuncCallback)
util.ExportFuncII(env, "go_vtab_begin", vtabBeginCallback)
util.ExportFuncII(env, "go_vtab_sync", vtabSyncCallback)
util.ExportFuncII(env, "go_vtab_commit", vtabCommitCallback)
util.ExportFuncII(env, "go_vtab_rollback", vtabRollbackCallback)
util.ExportFuncIII(env, "go_vtab_savepoint", vtabSavepointCallback)
util.ExportFuncIII(env, "go_vtab_release", vtabReleaseCallback)
util.ExportFuncIII(env, "go_vtab_rollback_to", vtabRollbackToCallback)
util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback)
util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback)
util.ExportFuncII(env, "go_cur_close", cursorCloseCallback)
util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback)
util.ExportFuncII(env, "go_cur_next", cursorNextCallback)
util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback)
util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback)
util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback)
return env
}
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Stmt is a prepared statement object.
//
// https://sqlite.org/c3ref/stmt.html
type Stmt struct {
c *Conn
err error
sql string
handle ptr_t
}
// Close destroys the prepared statement object.
//
// It is safe to close a nil, zero or closed Stmt.
//
// https://sqlite.org/c3ref/finalize.html
func (s *Stmt) Close() error {
if s == nil || s.handle == 0 {
return nil
}
rc := res_t(s.c.call("sqlite3_finalize", stk_t(s.handle)))
stmts := s.c.stmts
for i := range stmts {
if s == stmts[i] {
l := len(stmts) - 1
stmts[i] = stmts[l]
stmts[l] = nil
s.c.stmts = stmts[:l]
break
}
}
s.handle = 0
return s.c.error(rc)
}
// Conn returns the database connection to which the prepared statement belongs.
//
// https://sqlite.org/c3ref/db_handle.html
func (s *Stmt) Conn() *Conn {
return s.c
}
// SQL returns the SQL text used to create the prepared statement.
//
// https://sqlite.org/c3ref/expanded_sql.html
func (s *Stmt) SQL() string {
return s.sql
}
// ExpandedSQL returns the SQL text of the prepared statement
// with bound parameters expanded.
//
// https://sqlite.org/c3ref/expanded_sql.html
func (s *Stmt) ExpandedSQL() string {
ptr := ptr_t(s.c.call("sqlite3_expanded_sql", stk_t(s.handle)))
sql := util.ReadString(s.c.mod, ptr, _MAX_SQL_LENGTH)
s.c.free(ptr)
return sql
}
// ReadOnly returns true if and only if the statement
// makes no direct changes to the content of the database file.
//
// https://sqlite.org/c3ref/stmt_readonly.html
func (s *Stmt) ReadOnly() bool {
b := int32(s.c.call("sqlite3_stmt_readonly", stk_t(s.handle)))
return b != 0
}
// Reset resets the prepared statement object.
//
// https://sqlite.org/c3ref/reset.html
func (s *Stmt) Reset() error {
rc := res_t(s.c.call("sqlite3_reset", stk_t(s.handle)))
s.err = nil
return s.c.error(rc)
}
// Busy determines if a prepared statement has been stepped,
// but not run to completion or reset.
//
// https://sqlite.org/c3ref/stmt_busy.html
func (s *Stmt) Busy() bool {
rc := res_t(s.c.call("sqlite3_stmt_busy", stk_t(s.handle)))
return rc != 0
}
// Step evaluates the SQL statement.
// If the SQL statement being executed returns any data,
// then true is returned each time a new row of data is ready for processing by the caller.
// The values may be accessed using the Column access functions.
// Step is called again to retrieve the next row of data.
// If an error has occurred, Step returns false;
// call [Stmt.Err] or [Stmt.Reset] to get the error.
//
// https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
if s.c.interrupt.Err() != nil {
s.err = INTERRUPT
return false
}
return s.step()
}
func (s *Stmt) step() bool {
rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle)))
switch rc {
case _ROW:
s.err = nil
return true
case _DONE:
s.err = nil
default:
s.err = s.c.error(rc)
}
return false
}
// Err gets the last error that occurred during [Stmt.Step].
// Err returns nil after [Stmt.Reset] is called.
//
// https://sqlite.org/c3ref/step.html
func (s *Stmt) Err() error {
return s.err
}
// Exec is a convenience function that repeatedly calls [Stmt.Step] until it returns false,
// then calls [Stmt.Reset] to reset the statement and get any error that occurred.
func (s *Stmt) Exec() error {
if s.c.interrupt.Err() != nil {
return INTERRUPT
}
// TODO: implement this in C.
for s.step() {
}
return s.Reset()
}
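// Example (an illustrative sketch, not part of the original source; it
// assumes Conn.Prepare returning (*Stmt, string, error) as elsewhere in this
// module):
//
//	stmt, _, err := c.Prepare(`SELECT name FROM users WHERE id = ?`)
//	if err != nil {
//		return err
//	}
//	defer stmt.Close()
//	if err := stmt.BindInt64(1, 42); err != nil {
//		return err
//	}
//	for stmt.Step() {
//		// read the current row with the Column accessors
//	}
//	if err := stmt.Err(); err != nil {
//		return err
//	}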
// Status monitors the performance characteristics of prepared statements.
//
// https://sqlite.org/c3ref/stmt_status.html
func (s *Stmt) Status(op StmtStatus, reset bool) int {
if op > STMTSTATUS_FILTER_HIT && op != STMTSTATUS_MEMUSED {
return 0
}
var i int32
if reset {
i = 1
}
n := int32(s.c.call("sqlite3_stmt_status", stk_t(s.handle),
stk_t(op), stk_t(i)))
return int(n)
}
// ClearBindings resets all bindings on the prepared statement.
//
// https://sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
rc := res_t(s.c.call("sqlite3_clear_bindings", stk_t(s.handle)))
return s.c.error(rc)
}
// BindCount returns the number of SQL parameters in the prepared statement.
//
// https://sqlite.org/c3ref/bind_parameter_count.html
func (s *Stmt) BindCount() int {
n := int32(s.c.call("sqlite3_bind_parameter_count",
stk_t(s.handle)))
return int(n)
}
// BindIndex returns the index of a parameter in the prepared statement
// given its name.
//
// https://sqlite.org/c3ref/bind_parameter_index.html
func (s *Stmt) BindIndex(name string) int {
defer s.c.arena.mark()()
namePtr := s.c.arena.string(name)
i := int32(s.c.call("sqlite3_bind_parameter_index",
stk_t(s.handle), stk_t(namePtr)))
return int(i)
}
// BindName returns the name of a parameter in the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_parameter_name.html
func (s *Stmt) BindName(param int) string {
ptr := ptr_t(s.c.call("sqlite3_bind_parameter_name",
stk_t(s.handle), stk_t(param)))
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// BindBool binds a bool to the prepared statement.
// The leftmost SQL parameter has an index of 1.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBool(param int, value bool) error {
var i int64
if value {
i = 1
}
return s.BindInt64(param, i)
}
// BindInt binds an int to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindInt(param int, value int) error {
return s.BindInt64(param, int64(value))
}
// BindInt64 binds an int64 to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindInt64(param int, value int64) error {
rc := res_t(s.c.call("sqlite3_bind_int64",
stk_t(s.handle), stk_t(param), stk_t(value)))
return s.c.error(rc)
}
// BindFloat binds a float64 to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindFloat(param int, value float64) error {
rc := res_t(s.c.call("sqlite3_bind_double",
stk_t(s.handle), stk_t(param),
stk_t(math.Float64bits(value))))
return s.c.error(rc)
}
// BindText binds a string to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindText(param int, value string) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
ptr := s.c.newString(value)
rc := res_t(s.c.call("sqlite3_bind_text_go",
stk_t(s.handle), stk_t(param),
stk_t(ptr), stk_t(len(value))))
return s.c.error(rc)
}
// BindRawText binds a []byte to the prepared statement as text.
// The leftmost SQL parameter has an index of 1.
// Binding a nil slice is the same as calling [Stmt.BindNull].
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindRawText(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
ptr := s.c.newBytes(value)
rc := res_t(s.c.call("sqlite3_bind_text_go",
stk_t(s.handle), stk_t(param),
stk_t(ptr), stk_t(len(value))))
return s.c.error(rc)
}
// BindBlob binds a []byte to the prepared statement.
// The leftmost SQL parameter has an index of 1.
// Binding a nil slice is the same as calling [Stmt.BindNull].
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBlob(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
ptr := s.c.newBytes(value)
rc := res_t(s.c.call("sqlite3_bind_blob_go",
stk_t(s.handle), stk_t(param),
stk_t(ptr), stk_t(len(value))))
return s.c.error(rc)
}
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindZeroBlob(param int, n int64) error {
rc := res_t(s.c.call("sqlite3_bind_zeroblob64",
stk_t(s.handle), stk_t(param), stk_t(n)))
return s.c.error(rc)
}
// BindNull binds a NULL to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindNull(param int) error {
rc := res_t(s.c.call("sqlite3_bind_null",
stk_t(s.handle), stk_t(param)))
return s.c.error(rc)
}
// BindTime binds a [time.Time] to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
switch format {
case TimeFormatDefault, TimeFormatAuto, time.RFC3339Nano:
return s.bindRFC3339Nano(param, value)
}
switch v := format.Encode(value).(type) {
case string:
return s.BindText(param, v)
case int64:
return s.BindInt64(param, v)
case float64:
return s.BindFloat(param, v)
default:
panic(util.AssertErr())
}
}
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
const maxlen = int64(len(time.RFC3339Nano)) + 5
ptr := s.c.new(maxlen)
buf := util.View(s.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
rc := res_t(s.c.call("sqlite3_bind_text_go",
stk_t(s.handle), stk_t(param),
stk_t(ptr), stk_t(len(buf))))
return s.c.error(rc)
}
// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull],
// but it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindPointer(param int, ptr any) error {
valPtr := util.AddHandle(s.c.ctx, ptr)
rc := res_t(s.c.call("sqlite3_bind_pointer_go",
stk_t(s.handle), stk_t(param), stk_t(valPtr)))
return s.c.error(rc)
}
// BindJSON binds the JSON encoding of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindJSON(param int, value any) error {
data, err := json.Marshal(value)
if err != nil {
return err
}
return s.BindRawText(param, data)
}
// BindValue binds a copy of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindValue(param int, value Value) error {
if value.c != s.c {
return MISUSE
}
rc := res_t(s.c.call("sqlite3_bind_value",
stk_t(s.handle), stk_t(param), stk_t(value.handle)))
return s.c.error(rc)
}
// DataCount returns the number of columns in a result set.
//
// https://sqlite.org/c3ref/data_count.html
func (s *Stmt) DataCount() int {
n := int32(s.c.call("sqlite3_data_count",
stk_t(s.handle)))
return int(n)
}
// ColumnCount returns the number of columns in a result set.
//
// https://sqlite.org/c3ref/column_count.html
func (s *Stmt) ColumnCount() int {
n := int32(s.c.call("sqlite3_column_count",
stk_t(s.handle)))
return int(n)
}
// ColumnName returns the name of the result column.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_name.html
func (s *Stmt) ColumnName(col int) string {
ptr := ptr_t(s.c.call("sqlite3_column_name",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
panic(util.OOMErr)
}
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnType returns the initial [Datatype] of the result column.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnType(col int) Datatype {
return Datatype(s.c.call("sqlite3_column_type",
stk_t(s.handle), stk_t(col)))
}
// ColumnDeclType returns the declared datatype of the result column.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_decltype.html
func (s *Stmt) ColumnDeclType(col int) string {
ptr := ptr_t(s.c.call("sqlite3_column_decltype",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnDatabaseName returns the name of the database
// that is the origin of a particular result column.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_database_name.html
func (s *Stmt) ColumnDatabaseName(col int) string {
ptr := ptr_t(s.c.call("sqlite3_column_database_name",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnTableName returns the name of the table
// that is the origin of a particular result column.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_database_name.html
func (s *Stmt) ColumnTableName(col int) string {
ptr := ptr_t(s.c.call("sqlite3_column_table_name",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnOriginName returns the name of the table column
// that is the origin of a particular result column.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_database_name.html
func (s *Stmt) ColumnOriginName(col int) string {
ptr := ptr_t(s.c.call("sqlite3_column_origin_name",
stk_t(s.handle), stk_t(col)))
if ptr == 0 {
return ""
}
return util.ReadString(s.c.mod, ptr, _MAX_NAME)
}
// ColumnBool returns the value of the result column as a bool.
// The leftmost column of the result set has the index 0.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are retrieved as numbers,
// with 0 converted to false and any other value to true.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBool(col int) bool {
return s.ColumnFloat(col) != 0
}
// ColumnInt returns the value of the result column as an int.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnInt(col int) int {
return int(s.ColumnInt64(col))
}
// ColumnInt64 returns the value of the result column as an int64.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnInt64(col int) int64 {
return int64(s.c.call("sqlite3_column_int64",
stk_t(s.handle), stk_t(col)))
}
// ColumnFloat returns the value of the result column as a float64.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnFloat(col int) float64 {
f := uint64(s.c.call("sqlite3_column_double",
stk_t(s.handle), stk_t(col)))
return math.Float64frombits(f)
}
// ColumnTime returns the value of the result column as a [time.Time].
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
var v any
switch s.ColumnType(col) {
case INTEGER:
v = s.ColumnInt64(col)
case FLOAT:
v = s.ColumnFloat(col)
case TEXT, BLOB:
v = s.ColumnText(col)
case NULL:
return time.Time{}
default:
panic(util.AssertErr())
}
t, err := format.Decode(v)
if err != nil {
s.err = err
}
return t
}
// ColumnText returns the value of the result column as a string.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnText(col int) string {
return string(s.ColumnRawText(col))
}
// ColumnBlob appends to buf and returns
// the value of the result column as a []byte.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
return append(buf, s.ColumnRawBlob(col)...)
}
// ColumnRawText returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawText(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_text",
stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr)
}
// ColumnRawBlob returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawBlob(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_blob",
stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr)
}
func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
if ptr == 0 {
rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle)))
if rc != _ROW && rc != _DONE {
s.err = s.c.error(rc)
}
return nil
}
n := int32(s.c.call("sqlite3_column_bytes",
stk_t(s.handle), stk_t(col)))
return util.View(s.c.mod, ptr, int64(n))
}
// ColumnJSON parses the JSON-encoded value of the result column
// and stores it in the value pointed to by ptr.
// The leftmost column of the result set has the index 0.
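//
// A minimal sketch (assumes column 0 holds a JSON object):
//
// var row struct{ Name string `json:"name"` }
// err := stmt.ColumnJSON(0, &row)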
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnJSON(col int, ptr any) error {
var data []byte
switch s.ColumnType(col) {
case NULL:
data = []byte("null")
case TEXT:
data = s.ColumnRawText(col)
case BLOB:
data = s.ColumnRawBlob(col)
case INTEGER:
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
case FLOAT:
data = util.AppendNumber(nil, s.ColumnFloat(col))
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}
// ColumnValue returns the unprotected value of the result column.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnValue(col int) Value {
ptr := ptr_t(s.c.call("sqlite3_column_value",
stk_t(s.handle), stk_t(col)))
return Value{
c: s.c,
unprot: true,
handle: ptr,
}
}
// Columns populates result columns into the provided slice.
// The slice must have [Stmt.ColumnCount] length.
//
// [INTEGER] columns will be retrieved as int64 values,
// [FLOAT] as float64, [NULL] as nil,
// [TEXT] as string, and [BLOB] as []byte.
// Any returned []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
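//
// A minimal sketch (assumes a prior [Stmt.Step] returned a row):
//
// dest := make([]any, stmt.ColumnCount())
// if err := stmt.Columns(dest...); err != nil {
// // handle error
// }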
func (s *Stmt) Columns(dest ...any) error {
defer s.c.arena.mark()()
count := int64(len(dest))
typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(count * 8)
rc := res_t(s.c.call("sqlite3_columns_go",
stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
if err := s.c.error(rc); err != nil {
return err
}
types := util.View(s.c.mod, typePtr, count)
// Avoid bounds checks on types below.
if len(types) != len(dest) {
panic(util.AssertErr())
}
for i := range dest {
switch types[i] {
case byte(INTEGER):
dest[i] = util.Read64[int64](s.c.mod, dataPtr)
case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, dataPtr)
case byte(NULL):
dest[i] = nil
default:
len := util.Read32[int32](s.c.mod, dataPtr+4)
if len != 0 {
ptr := util.Read32[ptr_t](s.c.mod, dataPtr)
buf := util.View(s.c.mod, ptr, int64(len))
if types[i] == byte(TEXT) {
dest[i] = string(buf)
} else {
dest[i] = buf
}
} else {
if types[i] == byte(TEXT) {
dest[i] = ""
} else {
dest[i] = []byte{}
}
}
}
dataPtr += 8
}
return nil
}
package sqlite3
import (
"math"
"strconv"
"strings"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
)
// TimeFormat specifies how to encode/decode time values.
//
// See the documentation for the [TimeFormatDefault] constant
// for formats recognized by SQLite.
//
// https://sqlite.org/lang_datefunc.html
type TimeFormat string
// TimeFormats recognized by SQLite to encode/decode time values.
//
// https://sqlite.org/lang_datefunc.html#time_values
const (
TimeFormatDefault TimeFormat = "" // time.RFC3339Nano
// Text formats
TimeFormat1 TimeFormat = "2006-01-02"
TimeFormat2 TimeFormat = "2006-01-02 15:04"
TimeFormat3 TimeFormat = "2006-01-02 15:04:05"
TimeFormat4 TimeFormat = "2006-01-02 15:04:05.000"
TimeFormat5 TimeFormat = "2006-01-02T15:04"
TimeFormat6 TimeFormat = "2006-01-02T15:04:05"
TimeFormat7 TimeFormat = "2006-01-02T15:04:05.000"
TimeFormat8 TimeFormat = "15:04"
TimeFormat9 TimeFormat = "15:04:05"
TimeFormat10 TimeFormat = "15:04:05.000"
TimeFormat2TZ = TimeFormat2 + "Z07:00"
TimeFormat3TZ = TimeFormat3 + "Z07:00"
TimeFormat4TZ = TimeFormat4 + "Z07:00"
TimeFormat5TZ = TimeFormat5 + "Z07:00"
TimeFormat6TZ = TimeFormat6 + "Z07:00"
TimeFormat7TZ = TimeFormat7 + "Z07:00"
TimeFormat8TZ = TimeFormat8 + "Z07:00"
TimeFormat9TZ = TimeFormat9 + "Z07:00"
TimeFormat10TZ = TimeFormat10 + "Z07:00"
// Numeric formats
TimeFormatJulianDay TimeFormat = "julianday"
TimeFormatUnix TimeFormat = "unixepoch"
TimeFormatUnixFrac TimeFormat = "unixepoch_frac"
TimeFormatUnixMilli TimeFormat = "unixepoch_milli" // not an SQLite format
TimeFormatUnixMicro TimeFormat = "unixepoch_micro" // not an SQLite format
TimeFormatUnixNano TimeFormat = "unixepoch_nano" // not an SQLite format
// Auto
TimeFormatAuto TimeFormat = "auto"
)
// Encode encodes a time value using this format.
//
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// with nanosecond accuracy, and preserving any timezone offset.
//
// This is the format used by the [database/sql] driver:
// [database/sql.Row.Scan] will decode as [time.Time]
// values encoded with [time.RFC3339Nano].
//
// Time values encoded with [time.RFC3339Nano] cannot be sorted as strings
// to produce a time-ordered sequence.
//
// Assuming that the time zones of the time values are the same (e.g., all in UTC),
// and expressed using the same string (e.g., all "Z" or all "+00:00"),
// use the TIME [collating sequence] to produce a time-ordered sequence.
//
// Otherwise, use [TimeFormat7] for time-ordered encoding.
//
// Formats [TimeFormat1] through [TimeFormat10]
// convert time values to UTC before encoding.
//
// Returns a string for the text formats,
// a float64 for [TimeFormatJulianDay] and [TimeFormatUnixFrac],
// or an int64 for the other numeric formats.
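//
// For example (a sketch; t is any [time.Time] value):
//
// TimeFormatUnix.Encode(t) // int64, seconds since the Unix epoch
// TimeFormat3.Encode(t) // string like "2006-01-02 15:04:05", in UTC
// TimeFormatJulianDay.Encode(t) // float64, Julian day number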
//
// https://sqlite.org/lang_datefunc.html
//
// [collating sequence]: https://sqlite.org/datatype3.html#collating_sequences
func (f TimeFormat) Encode(t time.Time) any {
switch f {
// Numeric formats
case TimeFormatJulianDay:
return julianday.Float(t)
case TimeFormatUnix:
return t.Unix()
case TimeFormatUnixFrac:
return float64(t.Unix()) + float64(t.Nanosecond())*1e-9
case TimeFormatUnixMilli:
return t.UnixMilli()
case TimeFormatUnixMicro:
return t.UnixMicro()
case TimeFormatUnixNano:
return t.UnixNano()
// Special formats.
case TimeFormatDefault, TimeFormatAuto:
f = time.RFC3339Nano
// SQLite assumes UTC if unspecified.
case
TimeFormat1, TimeFormat2,
TimeFormat3, TimeFormat4,
TimeFormat5, TimeFormat6,
TimeFormat7, TimeFormat8,
TimeFormat9, TimeFormat10:
t = t.UTC()
}
return t.Format(string(f))
}
// Decode decodes a time value using this format.
//
// The time value can be a string, an int64, or a float64.
//
// Formats [TimeFormat8] through [TimeFormat10]
// (and [TimeFormat8TZ] through [TimeFormat10TZ])
// assume a date of 2000-01-01.
//
// The timezone indicator and fractional seconds are always optional
// for formats [TimeFormat2] through [TimeFormat10]
// (and [TimeFormat2TZ] through [TimeFormat10TZ]).
//
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
// Julian day numbers are safe to use for historical dates,
// from 4712BC through 9999AD.
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds)
// are safe to use for current events, from at least 1980 through at least 2260.
// Unix timestamps before 1980 and after 9999 may be misinterpreted as Julian day numbers,
// or have the wrong time unit.
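//
// For example (a sketch):
//
// t, err := TimeFormatAuto.Decode("2006-01-02 15:04:05")
// t, err = TimeFormatUnix.Decode(int64(1136239445))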
//
// https://sqlite.org/lang_datefunc.html
func (f TimeFormat) Decode(v any) (time.Time, error) {
if t, ok := v.(time.Time); ok {
return t, nil
}
switch f {
// Numeric formats.
case TimeFormatJulianDay:
switch v := v.(type) {
case string:
return julianday.Parse(v)
case float64:
return julianday.FloatTime(v), nil
case int64:
return julianday.Time(v, 0), nil
default:
return time.Time{}, util.TimeErr
}
case TimeFormatUnix, TimeFormatUnixFrac:
if s, ok := v.(string); ok {
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return time.Time{}, err
}
v = f
}
switch v := v.(type) {
case float64:
sec, frac := math.Modf(v)
nsec := math.Floor(frac * 1e9)
return time.Unix(int64(sec), int64(nsec)).UTC(), nil
case int64:
return time.Unix(v, 0).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMilli:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, err
}
v = i
}
switch v := v.(type) {
case float64:
return time.UnixMilli(int64(math.Floor(v))).UTC(), nil
case int64:
return time.UnixMilli(v).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMicro:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, err
}
v = i
}
switch v := v.(type) {
case float64:
return time.UnixMicro(int64(math.Floor(v))).UTC(), nil
case int64:
return time.UnixMicro(v).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
case TimeFormatUnixNano:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, err
}
v = i
}
switch v := v.(type) {
case float64:
return time.Unix(0, int64(math.Floor(v))).UTC(), nil
case int64:
return time.Unix(0, v).UTC(), nil
default:
return time.Time{}, util.TimeErr
}
// Special formats.
case TimeFormatAuto:
switch s := v.(type) {
case string:
i, err := strconv.ParseInt(s, 10, 64)
if err == nil {
v = i
break
}
f, err := strconv.ParseFloat(s, 64)
if err == nil {
v = f
break
}
dates := []TimeFormat{
TimeFormat9, TimeFormat8,
TimeFormat6, TimeFormat5,
TimeFormat3, TimeFormat2, TimeFormat1,
}
for _, f := range dates {
t, err := f.Decode(s)
if err == nil {
return t, nil
}
}
}
switch v := v.(type) {
case float64:
if 0 <= v && v < 5373484.5 {
return TimeFormatJulianDay.Decode(v)
}
if v < 253402300800 {
return TimeFormatUnixFrac.Decode(v)
}
if v < 253402300800_000 {
return TimeFormatUnixMilli.Decode(v)
}
if v < 253402300800_000000 {
return TimeFormatUnixMicro.Decode(v)
}
return TimeFormatUnixNano.Decode(v)
case int64:
if 0 <= v && v < 5373485 {
return TimeFormatJulianDay.Decode(v)
}
if v < 253402300800 {
return TimeFormatUnixFrac.Decode(v)
}
if v < 253402300800_000 {
return TimeFormatUnixMilli.Decode(v)
}
if v < 253402300800_000000 {
return TimeFormatUnixMicro.Decode(v)
}
return TimeFormatUnixNano.Decode(v)
default:
return time.Time{}, util.TimeErr
}
case
TimeFormat2, TimeFormat2TZ,
TimeFormat3, TimeFormat3TZ,
TimeFormat4, TimeFormat4TZ,
TimeFormat5, TimeFormat5TZ,
TimeFormat6, TimeFormat6TZ,
TimeFormat7, TimeFormat7TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, util.TimeErr
}
return f.parseRelaxed(s)
case
TimeFormat8, TimeFormat8TZ,
TimeFormat9, TimeFormat9TZ,
TimeFormat10, TimeFormat10TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, util.TimeErr
}
t, err := f.parseRelaxed(s)
if err != nil {
return time.Time{}, err
}
return t.AddDate(2000, 0, 0), nil
default:
s, ok := v.(string)
if !ok {
return time.Time{}, util.TimeErr
}
if f == "" {
f = time.RFC3339Nano
}
return time.Parse(string(f), s)
}
}
func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
fs := string(f)
fs = strings.TrimSuffix(fs, "Z07:00")
fs = strings.TrimSuffix(fs, ".000")
t, err := time.Parse(fs+"Z07:00", s)
if err != nil {
return time.Parse(fs, s)
}
return t, nil
}
// Scanner returns a [database/sql.Scanner] that can be used as an argument to
// [database/sql.Row.Scan] and similar methods to
// decode a time value into dest using this format.
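//
// A minimal sketch using the [database/sql] driver
// (db and the users table are assumptions):
//
// var created time.Time
// err := db.QueryRow(`SELECT created FROM users`).
// Scan(TimeFormatUnix.Scanner(&created))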
func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } {
return timeScanner{dest, f}
}
type timeScanner struct {
*time.Time
TimeFormat
}
func (s timeScanner) Scan(src any) error {
var ok bool
var err error
if *s.Time, ok = src.(time.Time); !ok {
*s.Time, err = s.Decode(src)
}
return err
}
package sqlite3
import (
"context"
"math/rand"
"runtime"
"strconv"
"strings"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Txn is an in-progress database transaction.
//
// https://sqlite.org/lang_transaction.html
type Txn struct {
c *Conn
}
// Begin starts a deferred transaction.
// Panics if a transaction is already in progress.
// For nested transactions, use [Conn.Savepoint].
//
// https://sqlite.org/lang_transaction.html
func (c *Conn) Begin() Txn {
// BEGIN even if interrupted.
err := c.exec(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
return Txn{c}
}
// BeginConcurrent starts a concurrent transaction.
//
// Experimental: requires a custom build of SQLite.
//
// https://sqlite.org/cgi/src/doc/begin-concurrent/doc/begin_concurrent.md
func (c *Conn) BeginConcurrent() (Txn, error) {
err := c.Exec(`BEGIN CONCURRENT`)
if err != nil {
return Txn{}, err
}
return Txn{c}, nil
}
// BeginImmediate starts an immediate transaction.
//
// https://sqlite.org/lang_transaction.html
func (c *Conn) BeginImmediate() (Txn, error) {
err := c.Exec(`BEGIN IMMEDIATE`)
if err != nil {
return Txn{}, err
}
return Txn{c}, nil
}
// BeginExclusive starts an exclusive transaction.
//
// https://sqlite.org/lang_transaction.html
func (c *Conn) BeginExclusive() (Txn, error) {
err := c.Exec(`BEGIN EXCLUSIVE`)
if err != nil {
return Txn{}, err
}
return Txn{c}, nil
}
// End calls either [Txn.Commit] or [Txn.Rollback]
// depending on whether *error points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(db *sqlite3.Conn) (err error) {
// tx := db.Begin()
// defer tx.End(&err)
//
// // ... do work in the transaction
// }
//
// https://sqlite.org/lang_transaction.html
func (tx Txn) End(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if *errp == nil && recovered == nil {
// Success path.
if tx.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = tx.Commit()
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
if tx.c.GetAutocommit() { // There is nothing to rollback.
return
}
err := tx.Rollback()
if err != nil {
panic(err)
}
}
// Commit commits the transaction.
//
// https://sqlite.org/lang_transaction.html
func (tx Txn) Commit() error {
return tx.c.Exec(`COMMIT`)
}
// Rollback rolls back the transaction,
// even if the connection has been interrupted.
//
// https://sqlite.org/lang_transaction.html
func (tx Txn) Rollback() error {
// ROLLBACK even if interrupted.
return tx.c.exec(`ROLLBACK`)
}
// Savepoint is a marker within a transaction
// that allows for partial rollback.
//
// https://sqlite.org/lang_savepoint.html
type Savepoint struct {
c *Conn
name string
}
// Savepoint establishes a new transaction savepoint.
//
// https://sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
name := callerName()
if name == "" {
name = "sqlite3.Savepoint"
}
// Names can be reused, but this makes catching bugs more likely.
name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31())))
err := c.exec(`SAVEPOINT ` + name)
if err != nil {
panic(err)
}
return Savepoint{c: c, name: name}
}
func callerName() (name string) {
var pc [8]uintptr
n := runtime.Callers(3, pc[:])
if n <= 0 {
return ""
}
frames := runtime.CallersFrames(pc[:n])
frame, more := frames.Next()
for more && (strings.HasPrefix(frame.Function, "database/sql.") ||
strings.HasPrefix(frame.Function, "github.com/ncruces/go-sqlite3/driver.")) {
frame, more = frames.Next()
}
return frame.Function
}
// Release releases the savepoint, rolling back any changes
// if *error points to a non-nil error.
//
// This is meant to be deferred:
//
// func doWork(db *sqlite3.Conn) (err error) {
// savept := db.Savepoint()
// defer savept.Release(&err)
//
// // ... do work in the transaction
// }
func (s Savepoint) Release(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if *errp == nil && recovered == nil {
// Success path.
if s.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = s.c.Exec(`RELEASE ` + s.name)
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
if s.c.GetAutocommit() { // There is nothing to rollback.
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.exec(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
if err != nil {
panic(err)
}
}
// Rollback rolls the transaction back to the savepoint,
// even if the connection has been interrupted.
// Rollback does not release the savepoint.
//
// https://sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.exec(`ROLLBACK TO ` + s.name)
}
// TxnState determines the transaction state of a database.
//
// https://sqlite.org/c3ref/txn_state.html
func (c *Conn) TxnState(schema string) TxnState {
var ptr ptr_t
if schema != "" {
defer c.arena.mark()()
ptr = c.arena.string(schema)
}
return TxnState(c.call("sqlite3_txn_state", stk_t(c.handle), stk_t(ptr)))
}
// CommitHook registers a callback function to be invoked
// whenever a transaction is committed.
// Return true to allow the commit operation to continue normally.
//
// https://sqlite.org/c3ref/commit_hook.html
func (c *Conn) CommitHook(cb func() (ok bool)) {
var enable int32
if cb != nil {
enable = 1
}
c.call("sqlite3_commit_hook_go", stk_t(c.handle), stk_t(enable))
c.commit = cb
}
// RollbackHook registers a callback function to be invoked
// whenever a transaction is rolled back.
//
// https://sqlite.org/c3ref/commit_hook.html
func (c *Conn) RollbackHook(cb func()) {
var enable int32
if cb != nil {
enable = 1
}
c.call("sqlite3_rollback_hook_go", stk_t(c.handle), stk_t(enable))
c.rollback = cb
}
// UpdateHook registers a callback function to be invoked
// whenever a row is updated, inserted or deleted in a rowid table.
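//
// A minimal sketch that logs every change
// (conn is assumed to be an open [Conn]):
//
// conn.UpdateHook(func(action AuthorizerActionCode, schema, table string, rowid int64) {
// log.Printf("%v on %s.%s rowid=%d", action, schema, table, rowid)
// })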
//
// https://sqlite.org/c3ref/update_hook.html
func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) {
var enable int32
if cb != nil {
enable = 1
}
c.call("sqlite3_update_hook_go", stk_t(c.handle), stk_t(enable))
c.update = cb
}
func commitCallback(ctx context.Context, mod api.Module, pDB ptr_t) (rollback int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil {
if !c.commit() {
rollback = 1
}
}
return rollback
}
func rollbackCallback(ctx context.Context, mod api.Module, pDB ptr_t) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.rollback != nil {
c.rollback()
}
}
func updateCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zSchema, zTabName ptr_t, rowid int64) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil {
schema := util.ReadString(mod, zSchema, _MAX_NAME)
table := util.ReadString(mod, zTabName, _MAX_NAME)
c.update(action, schema, table, rowid)
}
}
// CacheFlush flushes caches to disk mid-transaction.
//
// https://sqlite.org/c3ref/db_cacheflush.html
func (c *Conn) CacheFlush() error {
rc := res_t(c.call("sqlite3_db_cacheflush", stk_t(c.handle)))
return c.error(rc)
}
// Package fsutil implements filesystem utilities.
package fsutil
import (
"io/fs"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
// ParseFileMode parses a file mode as returned by
// [fs.FileMode.String].
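//
// For example (a sketch):
//
// mode, err := ParseFileMode("-rw-r--r--")
// // mode == 0o644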
func ParseFileMode(str string) (fs.FileMode, error) {
var mode fs.FileMode
err := util.ErrorString("invalid mode: " + str)
if len(str) < 10 {
return 0, err
}
for i, c := range []byte("dalTLDpSugct?") {
if str[0] == c {
if len(str) < 10 {
return 0, err
}
mode |= 1 << uint(32-1-i)
str = str[1:]
}
}
if mode == 0 {
if str[0] != '-' {
return 0, err
}
str = str[1:]
}
if len(str) != 9 {
return 0, err
}
for i, c := range []byte("rwxrwxrwx") {
if str[i] == c {
mode |= 1 << uint(9-1-i)
} else if str[i] != '-' {
return 0, err
}
}
return mode, nil
}
// FileModeFromUnix converts a POSIX mode_t to a file mode.
func FileModeFromUnix(mode fs.FileMode) fs.FileMode {
const (
S_IFMT fs.FileMode = 0170000
S_IFIFO fs.FileMode = 0010000
S_IFCHR fs.FileMode = 0020000
S_IFDIR fs.FileMode = 0040000
S_IFBLK fs.FileMode = 0060000
S_IFREG fs.FileMode = 0100000
S_IFLNK fs.FileMode = 0120000
S_IFSOCK fs.FileMode = 0140000
)
switch mode & S_IFMT {
case S_IFDIR:
mode |= fs.ModeDir
case S_IFLNK:
mode |= fs.ModeSymlink
case S_IFBLK:
mode |= fs.ModeDevice
case S_IFCHR:
mode |= fs.ModeCharDevice | fs.ModeDevice
case S_IFIFO:
mode |= fs.ModeNamedPipe
case S_IFSOCK:
mode |= fs.ModeSocket
case S_IFREG, 0:
//
default:
mode |= fs.ModeIrregular
}
return mode &^ S_IFMT
}
// FileModeFromValue calls [FileModeFromUnix] for numeric values,
// and [ParseFileMode] for textual values.
func FileModeFromValue(val sqlite3.Value) fs.FileMode {
if n := val.Int64(); n != 0 {
return FileModeFromUnix(fs.FileMode(n))
}
mode, _ := ParseFileMode(val.Text())
return mode
}
package ioutil
import (
"io"
"sync"
)
// SeekingReaderAt implements [io.ReaderAt]
// through an underlying [io.ReadSeeker].
type SeekingReaderAt struct {
// +checklocks:l
r io.ReadSeeker
l sync.Mutex
}
// NewSeekingReaderAt creates a new SeekingReaderAt.
// The SeekingReaderAt takes ownership of r
// and will modify its seek offset,
// so callers should not use r after this call.
func NewSeekingReaderAt(r io.ReadSeeker) *SeekingReaderAt {
return &SeekingReaderAt{r: r}
}
// ReadAt implements [io.ReaderAt].
func (s *SeekingReaderAt) ReadAt(p []byte, off int64) (n int, _ error) {
s.l.Lock()
defer s.l.Unlock()
_, err := s.r.Seek(off, io.SeekStart)
if err != nil {
return 0, err
}
for len(p) > 0 {
i, err := s.r.Read(p)
p = p[i:]
n += i
if err != nil {
return n, err
}
}
return n, nil
}
// Size implements [SizeReaderAt].
func (s *SeekingReaderAt) Size() (int64, error) {
s.l.Lock()
defer s.l.Unlock()
return s.r.Seek(0, io.SeekEnd)
}
// Close implements [io.Closer].
func (s *SeekingReaderAt) Close() error {
s.l.Lock()
defer s.l.Unlock()
if c, ok := s.r.(io.Closer); ok {
s.r = nil
return c.Close()
}
return nil
}
package ioutil
import (
"io"
"io/fs"
"github.com/ncruces/go-sqlite3"
)
// A SizeReaderAt is a ReaderAt with a Size method.
// Use [NewSizeReaderAt] to adapt different Size interfaces.
type SizeReaderAt interface {
Size() (int64, error)
io.ReaderAt
}
// NewSizeReaderAt returns a SizeReaderAt given an io.ReaderAt
// that implements one of:
// - Size() (int64, error)
// - Size() int64
// - Len() int
// - Stat() (fs.FileInfo, error)
// - Seek(offset int64, whence int) (int64, error)
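//
// For example, a [bytes.Reader] already provides Size, Len, and ReadAt
// (a sketch; data is any []byte):
//
// ra := NewSizeReaderAt(bytes.NewReader(data))
// size, err := ra.Size()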
func NewSizeReaderAt(r io.ReaderAt) SizeReaderAt {
return sizer{r}
}
type sizer struct{ io.ReaderAt }
func (s sizer) Size() (int64, error) {
switch s := s.ReaderAt.(type) {
case interface{ Size() (int64, error) }:
return s.Size()
case interface{ Size() int64 }:
return s.Size(), nil
case interface{ Len() int }:
return int64(s.Len()), nil
case interface{ Stat() (fs.FileInfo, error) }:
fi, err := s.Stat()
if err != nil {
return 0, err
}
return fi.Size(), nil
case io.Seeker:
return s.Seek(0, io.SeekEnd)
}
return 0, sqlite3.IOERR_SEEK
}
// Package osutil implements operating system utilities.
package osutil
import (
"io/fs"
"os"
)
// FS implements [fs.FS], [fs.StatFS], and [fs.ReadFileFS]
// using package [os].
//
// This filesystem does not respect [fs.ValidPath] rules,
// and fails [testing/fstest.TestFS]!
//
// Still, it can be a useful tool to unify implementations
// that can access either the [os] filesystem or an [fs.FS].
// It's OK to use this to open files, but you should avoid
// opening directories, resolving paths, or walking the file system.
type FS struct{}
// Open implements [fs.FS].
func (FS) Open(name string) (fs.File, error) {
return os.OpenFile(name, os.O_RDONLY, 0)
}
// Stat implements [fs.StatFS].
func (FS) Stat(name string) (fs.FileInfo, error) {
return os.Stat(name)
}
// ReadFile implements [fs.ReadFileFS].
func (FS) ReadFile(name string) ([]byte, error) {
return os.ReadFile(name)
}
// OpenFile behaves the same as [os.OpenFile].
//
// Deprecated: use os.OpenFile instead.
func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
package sql3util
import "strings"
// NamedArg splits a named arg into a key and value,
// around an equals sign.
// Spaces are trimmed around both key and value.
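//
// For example (a sketch):
//
// key, val := NamedArg(" page_size = 4096 ")
// // key == "page_size", val == "4096"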
func NamedArg(arg string) (key, val string) {
key, val, _ = strings.Cut(arg, "=")
key = strings.TrimSpace(key)
val = strings.TrimSpace(val)
return
}
// Unquote unquotes a string.
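//
// For example (a sketch):
//
// Unquote(`'it''s'`) // it's
// Unquote(`[name]`) // name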
//
// https://sqlite.org/lang_keywords.html
func Unquote(val string) string {
if len(val) < 2 {
return val
}
fst := val[0]
lst := val[len(val)-1]
rst := val[1 : len(val)-1]
if fst == '[' && lst == ']' {
return rst
}
if fst != lst {
return val
}
var old, new string
switch fst {
default:
return val
case '`':
old, new = "``", "`"
case '"':
old, new = `""`, `"`
case '\'':
old, new = `''`, `'`
}
return strings.ReplaceAll(rst, old, new)
}
// ParseBool parses a boolean.
//
// https://sqlite.org/pragma.html#syntax
func ParseBool(s string) (b, ok bool) {
if len(s) == 0 {
return false, false
}
if s[0] == '0' {
return false, true
}
if '1' <= s[0] && s[0] <= '9' {
return true, true
}
switch strings.ToLower(s) {
case "true", "yes", "on":
return true, true
case "false", "no", "off":
return false, true
}
return false, false
}
package sql3util
import (
"context"
_ "embed"
"sync"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
)
const (
errp = 4
sqlp = 8
)
var (
//go:embed wasm/sql3parse_table.wasm
binary []byte
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
)
// ParseTable parses a [CREATE] or [ALTER TABLE] command.
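//
// A minimal sketch:
//
// tab, err := ParseTable(`CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)`)
// if err == nil {
// fmt.Println(tab.Name, len(tab.Columns)) // "t 2"
// }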
//
// [CREATE]: https://sqlite.org/lang_createtable.html
// [ALTER TABLE]: https://sqlite.org/lang_altertable.html
func ParseTable(sql string) (_ *Table, err error) {
once.Do(func() {
ctx := context.Background()
cfg := wazero.NewRuntimeConfigInterpreter()
runtime = wazero.NewRuntimeWithConfig(ctx, cfg)
compiled, err = runtime.CompileModule(ctx, binary)
})
if err != nil {
return nil, err
}
ctx := context.Background()
mod, err := runtime.InstantiateModule(ctx, compiled, wazero.NewModuleConfig().WithName(""))
if err != nil {
return nil, err
}
defer mod.Close(ctx)
if buf, ok := mod.Memory().Read(sqlp, uint32(len(sql))); ok {
copy(buf, sql)
}
stack := [...]util.Stk_t{sqlp, util.Stk_t(len(sql)), errp}
err = mod.ExportedFunction("sql3parse_table").CallWithStack(ctx, stack[:])
if err != nil {
return nil, err
}
c, _ := mod.Memory().ReadUint32Le(errp)
switch c {
case _MEMORY:
panic(util.OOMErr)
case _SYNTAX:
return nil, util.ErrorString("sql3parse: invalid syntax")
case _UNSUPPORTEDSQL:
return nil, util.ErrorString("sql3parse: unsupported SQL")
}
var tab Table
tab.load(mod, uint32(stack[0]), sql)
return &tab, nil
}
// Table holds metadata about a table.
type Table struct {
Name string
Schema string
Comment string
IsTemporary bool
IsIfNotExists bool
IsWithoutRowID bool
IsStrict bool
Columns []Column
Type StatementType
CurrentName string
NewName string
}
func (t *Table) load(mod api.Module, ptr uint32, sql string) {
t.Name = loadString(mod, ptr+0, sql)
t.Schema = loadString(mod, ptr+8, sql)
t.Comment = loadString(mod, ptr+16, sql)
t.IsTemporary = loadBool(mod, ptr+24)
t.IsIfNotExists = loadBool(mod, ptr+25)
t.IsWithoutRowID = loadBool(mod, ptr+26)
t.IsStrict = loadBool(mod, ptr+27)
t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, ret *Column) {
p, _ := mod.Memory().ReadUint32Le(ptr)
ret.load(mod, p, sql)
})
t.Type = loadEnum[StatementType](mod, ptr+44)
t.CurrentName = loadString(mod, ptr+48, sql)
t.NewName = loadString(mod, ptr+56, sql)
}
// Column holds metadata about a column.
type Column struct {
Name string
Type string
Length string
ConstraintName string
Comment string
IsPrimaryKey bool
IsAutoIncrement bool
IsNotNull bool
IsUnique bool
PKOrder OrderClause
PKConflictClause ConflictClause
NotNullConflictClause ConflictClause
UniqueConflictClause ConflictClause
CheckExpr string
DefaultExpr string
CollateName string
ForeignKeyClause *ForeignKey
}
func (c *Column) load(mod api.Module, ptr uint32, sql string) {
c.Name = loadString(mod, ptr+0, sql)
c.Type = loadString(mod, ptr+8, sql)
c.Length = loadString(mod, ptr+16, sql)
c.ConstraintName = loadString(mod, ptr+24, sql)
c.Comment = loadString(mod, ptr+32, sql)
c.IsPrimaryKey = loadBool(mod, ptr+40)
c.IsAutoIncrement = loadBool(mod, ptr+41)
c.IsNotNull = loadBool(mod, ptr+42)
c.IsUnique = loadBool(mod, ptr+43)
c.PKOrder = loadEnum[OrderClause](mod, ptr+44)
c.PKConflictClause = loadEnum[ConflictClause](mod, ptr+48)
c.NotNullConflictClause = loadEnum[ConflictClause](mod, ptr+52)
c.UniqueConflictClause = loadEnum[ConflictClause](mod, ptr+56)
c.CheckExpr = loadString(mod, ptr+60, sql)
c.DefaultExpr = loadString(mod, ptr+68, sql)
c.CollateName = loadString(mod, ptr+76, sql)
if ptr, _ := mod.Memory().ReadUint32Le(ptr + 84); ptr != 0 {
c.ForeignKeyClause = &ForeignKey{}
c.ForeignKeyClause.load(mod, ptr, sql)
}
}
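// ForeignKey holds metadata about a foreign key clause.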
type ForeignKey struct {
Table string
Columns []string
OnDelete FKAction
OnUpdate FKAction
Match string
Deferrable FKDefType
}
func (f *ForeignKey) load(mod api.Module, ptr uint32, sql string) {
f.Table = loadString(mod, ptr+0, sql)
f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, ret *string) {
*ret = loadString(mod, ptr, sql)
})
f.OnDelete = loadEnum[FKAction](mod, ptr+16)
f.OnUpdate = loadEnum[FKAction](mod, ptr+20)
f.Match = loadString(mod, ptr+24, sql)
f.Deferrable = loadEnum[FKDefType](mod, ptr+32)
}
func loadString(mod api.Module, ptr uint32, sql string) string {
off, _ := mod.Memory().ReadUint32Le(ptr + 0)
if off == 0 {
return ""
}
len, _ := mod.Memory().ReadUint32Le(ptr + 4)
return sql[off-sqlp : off+len-sqlp]
}
func loadSlice[T any](mod api.Module, ptr uint32, fn func(uint32, *T)) []T {
ref, _ := mod.Memory().ReadUint32Le(ptr + 4)
if ref == 0 {
return nil
}
len, _ := mod.Memory().ReadUint32Le(ptr + 0)
ret := make([]T, len)
for i := range ret {
fn(ref, &ret[i])
ref += 4
}
return ret
}
func loadEnum[T ~uint32](mod api.Module, ptr uint32) T {
val, _ := mod.Memory().ReadUint32Le(ptr)
return T(val)
}
func loadBool(mod api.Module, ptr uint32) bool {
val, _ := mod.Memory().ReadByte(ptr)
return val != 0
}
// Package sql3util implements SQLite utilities.
package sql3util
// ValidPageSize returns true if s is a valid page size.
//
// https://sqlite.org/fileformat.html#pages
func ValidPageSize(s int) bool {
return s&(s-1) == 0 && 512 <= s && s <= 65536
}
// Package vfsutil implements virtual filesystem utilities.
package vfsutil
import (
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
// UnwrapFile unwraps a [vfs.File],
// possibly implementing [vfs.FileUnwrap],
// to a concrete type.
func UnwrapFile[T vfs.File](f vfs.File) (_ T, _ bool) {
for {
switch t := f.(type) {
default:
return
case T:
return t, true
case vfs.FileUnwrap:
f = t.Unwrap()
}
}
}
// WrapOpenFilename helps wrap [vfs.VFSFilename].
func WrapOpenFilename(f vfs.VFS, name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
if f, ok := f.(vfs.VFSFilename); ok {
return f.OpenFilename(name, flags)
}
return f.Open(name.String(), flags)
}
// WrapLockState helps wrap [vfs.FileLockState].
func WrapLockState(f vfs.File) vfs.LockLevel {
if f, ok := f.(vfs.FileLockState); ok {
return f.LockState()
}
return vfs.LOCK_EXCLUSIVE + 1 // UNKNOWN_LOCK
}
// WrapPersistWAL helps wrap [vfs.FilePersistWAL].
func WrapPersistWAL(f vfs.File) bool {
if f, ok := f.(vfs.FilePersistWAL); ok {
return f.PersistWAL()
}
return false
}
// WrapSetPersistWAL helps wrap [vfs.FilePersistWAL].
func WrapSetPersistWAL(f vfs.File, keepWAL bool) {
if f, ok := f.(vfs.FilePersistWAL); ok {
f.SetPersistWAL(keepWAL)
}
}
// WrapPowersafeOverwrite helps wrap [vfs.FilePowersafeOverwrite].
func WrapPowersafeOverwrite(f vfs.File) bool {
if f, ok := f.(vfs.FilePowersafeOverwrite); ok {
return f.PowersafeOverwrite()
}
return false
}
// WrapSetPowersafeOverwrite helps wrap [vfs.FilePowersafeOverwrite].
func WrapSetPowersafeOverwrite(f vfs.File, psow bool) {
if f, ok := f.(vfs.FilePowersafeOverwrite); ok {
f.SetPowersafeOverwrite(psow)
}
}
// WrapChunkSize helps wrap [vfs.FileChunkSize].
func WrapChunkSize(f vfs.File, size int) {
if f, ok := f.(vfs.FileChunkSize); ok {
f.ChunkSize(size)
}
}
// WrapSizeHint helps wrap [vfs.FileSizeHint].
func WrapSizeHint(f vfs.File, size int64) error {
if f, ok := f.(vfs.FileSizeHint); ok {
return f.SizeHint(size)
}
return sqlite3.NOTFOUND
}
// WrapHasMoved helps wrap [vfs.FileHasMoved].
func WrapHasMoved(f vfs.File) (bool, error) {
if f, ok := f.(vfs.FileHasMoved); ok {
return f.HasMoved()
}
return false, sqlite3.NOTFOUND
}
// WrapOverwrite helps wrap [vfs.FileOverwrite].
func WrapOverwrite(f vfs.File) error {
if f, ok := f.(vfs.FileOverwrite); ok {
return f.Overwrite()
}
return sqlite3.NOTFOUND
}
// WrapSyncSuper helps wrap [vfs.FileSync].
func WrapSyncSuper(f vfs.File, super string) error {
if f, ok := f.(vfs.FileSync); ok {
return f.SyncSuper(super)
}
return sqlite3.NOTFOUND
}
// WrapCommitPhaseTwo helps wrap [vfs.FileCommitPhaseTwo].
func WrapCommitPhaseTwo(f vfs.File) error {
if f, ok := f.(vfs.FileCommitPhaseTwo); ok {
return f.CommitPhaseTwo()
}
return sqlite3.NOTFOUND
}
// WrapBeginAtomicWrite helps wrap [vfs.FileBatchAtomicWrite].
func WrapBeginAtomicWrite(f vfs.File) error {
if f, ok := f.(vfs.FileBatchAtomicWrite); ok {
return f.BeginAtomicWrite()
}
return sqlite3.NOTFOUND
}
// WrapCommitAtomicWrite helps wrap [vfs.FileBatchAtomicWrite].
func WrapCommitAtomicWrite(f vfs.File) error {
if f, ok := f.(vfs.FileBatchAtomicWrite); ok {
return f.CommitAtomicWrite()
}
return sqlite3.NOTFOUND
}
// WrapRollbackAtomicWrite helps wrap [vfs.FileBatchAtomicWrite].
func WrapRollbackAtomicWrite(f vfs.File) error {
if f, ok := f.(vfs.FileBatchAtomicWrite); ok {
return f.RollbackAtomicWrite()
}
return sqlite3.NOTFOUND
}
// WrapCheckpointStart helps wrap [vfs.FileCheckpoint].
func WrapCheckpointStart(f vfs.File) {
if f, ok := f.(vfs.FileCheckpoint); ok {
f.CheckpointStart()
}
}
// WrapCheckpointDone helps wrap [vfs.FileCheckpoint].
func WrapCheckpointDone(f vfs.File) {
if f, ok := f.(vfs.FileCheckpoint); ok {
f.CheckpointDone()
}
}
// WrapPragma helps wrap [vfs.FilePragma].
func WrapPragma(f vfs.File, name, value string) (string, error) {
if f, ok := f.(vfs.FilePragma); ok {
return f.Pragma(name, value)
}
return "", sqlite3.NOTFOUND
}
// WrapBusyHandler helps wrap [vfs.FileBusyHandler].
func WrapBusyHandler(f vfs.File, handler func() bool) {
if f, ok := f.(vfs.FileBusyHandler); ok {
f.BusyHandler(handler)
}
}
// WrapSharedMemory helps wrap [vfs.FileSharedMemory].
func WrapSharedMemory(f vfs.File) vfs.SharedMemory {
if f, ok := f.(vfs.FileSharedMemory); ok {
return f.SharedMemory()
}
return nil
}
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Value is any value that can be stored in a database table.
//
// https://sqlite.org/c3ref/value.html
type Value struct {
c *Conn
handle ptr_t
unprot bool
copied bool
}
func (v Value) protected() stk_t {
if v.unprot {
panic(util.ValueErr)
}
return stk_t(v.handle)
}
// Dup makes a copy of the SQL value and returns a pointer to that copy.
//
// https://sqlite.org/c3ref/value_dup.html
func (v Value) Dup() *Value {
ptr := ptr_t(v.c.call("sqlite3_value_dup", stk_t(v.handle)))
return &Value{
c: v.c,
copied: true,
handle: ptr,
}
}
// Close frees an SQL value previously obtained by [Value.Dup].
//
// https://sqlite.org/c3ref/value_dup.html
func (dup *Value) Close() error {
if !dup.copied {
panic(util.ValueErr)
}
dup.c.call("sqlite3_value_free", stk_t(dup.handle))
dup.handle = 0
return nil
}
// Type returns the initial datatype of the value.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Type() Datatype {
return Datatype(v.c.call("sqlite3_value_type", v.protected()))
}
// NumericType returns the numeric datatype of the value.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) NumericType() Datatype {
return Datatype(v.c.call("sqlite3_value_numeric_type", v.protected()))
}
// Bool returns the value as a bool.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are retrieved as numbers,
// with 0 converted to false and any other value to true.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Bool() bool {
return v.Float() != 0
}
// Int returns the value as an int.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Int() int {
return int(v.Int64())
}
// Int64 returns the value as an int64.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Int64() int64 {
return int64(v.c.call("sqlite3_value_int64", v.protected()))
}
// Float returns the value as a float64.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Float() float64 {
f := uint64(v.c.call("sqlite3_value_double", v.protected()))
return math.Float64frombits(f)
}
// Time returns the value as a [time.Time].
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Time(format TimeFormat) time.Time {
var a any
switch v.Type() {
case INTEGER:
a = v.Int64()
case FLOAT:
a = v.Float()
case TEXT, BLOB:
a = v.Text()
case NULL:
return time.Time{}
default:
panic(util.AssertErr())
}
t, _ := format.Decode(a)
return t
}
// Text returns the value as a string.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Text() string {
return string(v.RawText())
}
// Blob appends to buf and returns
// the value as a []byte.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Blob(buf []byte) []byte {
return append(buf, v.RawBlob()...)
}
// RawText returns the value as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Value] methods.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) RawText() []byte {
ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected()))
return v.rawBytes(ptr)
}
// RawBlob returns the value as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Value] methods.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) RawBlob() []byte {
ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected()))
return v.rawBytes(ptr)
}
func (v Value) rawBytes(ptr ptr_t) []byte {
if ptr == 0 {
return nil
}
n := int32(v.c.call("sqlite3_value_bytes", v.protected()))
return util.View(v.c.mod, ptr, int64(n))
}
// Pointer gets the pointer associated with this value,
// or nil if it has no associated pointer.
func (v Value) Pointer() any {
ptr := ptr_t(v.c.call("sqlite3_value_pointer_go", v.protected()))
return util.GetHandle(v.c.ctx, ptr)
}
// JSON parses a JSON-encoded value
// and stores the result in the value pointed to by ptr.
func (v Value) JSON(ptr any) error {
var data []byte
switch v.Type() {
case NULL:
data = []byte("null")
case TEXT:
data = v.RawText()
case BLOB:
data = v.RawBlob()
case INTEGER:
data = strconv.AppendInt(nil, v.Int64(), 10)
case FLOAT:
data = util.AppendNumber(nil, v.Float())
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}
// NoChange returns true if and only if the value is unchanged
// in a virtual table update operation.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) NoChange() bool {
b := int32(v.c.call("sqlite3_value_nochange", v.protected()))
return b != 0
}
// FromBind returns true if value originated from a bound parameter.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) FromBind() bool {
b := int32(v.c.call("sqlite3_value_frombind", v.protected()))
return b != 0
}
// InFirst returns the first element
// on the right-hand side of an IN constraint.
//
// https://sqlite.org/c3ref/vtab_in_first.html
func (v Value) InFirst() (Value, error) {
defer v.c.arena.mark()()
valPtr := v.c.arena.new(ptrlen)
rc := res_t(v.c.call("sqlite3_vtab_in_first", stk_t(v.handle), stk_t(valPtr)))
if err := v.c.error(rc); err != nil {
return Value{}, err
}
return Value{
c: v.c,
handle: util.Read32[ptr_t](v.c.mod, valPtr),
}, nil
}
// InNext returns the next element
// on the right-hand side of an IN constraint.
//
// https://sqlite.org/c3ref/vtab_in_first.html
func (v Value) InNext() (Value, error) {
defer v.c.arena.mark()()
valPtr := v.c.arena.new(ptrlen)
rc := res_t(v.c.call("sqlite3_vtab_in_next", stk_t(v.handle), stk_t(valPtr)))
if err := v.c.error(rc); err != nil {
return Value{}, err
}
return Value{
c: v.c,
handle: util.Read32[ptr_t](v.c.mod, valPtr),
}, nil
}
package adiantum
import (
"crypto/rand"
"golang.org/x/crypto/argon2"
"lukechampine.com/adiantum"
"lukechampine.com/adiantum/hbsh"
)
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3/vfs/adiantum.pepper=adiantum"
var pepper = "github.com/ncruces/go-sqlite3/vfs/adiantum"
type adiantumCreator struct{}
func (adiantumCreator) HBSH(key []byte) *hbsh.HBSH {
if len(key) != 32 {
return nil
}
return adiantum.New(key)
}
func (adiantumCreator) KDF(text string) []byte {
if text == "" {
key := make([]byte, 32)
n, _ := rand.Read(key)
return key[:n]
}
return argon2.IDKey([]byte(text), []byte(pepper), 3, 64*1024, 4, 32)
}
// Package adiantum wraps an SQLite VFS to offer encryption at rest.
//
// The "adiantum" [vfs.VFS] wraps the default VFS using the
// Adiantum tweakable, length-preserving encryption.
//
// Importing package adiantum registers that VFS:
//
// import _ "github.com/ncruces/go-sqlite3/vfs/adiantum"
//
// To open an encrypted database you need to provide key material.
//
// The simplest way to do that is to specify the key through a [URI] parameter:
//
// - key: key material in binary (32 bytes)
// - hexkey: key material in hex (64 hex digits)
// - textkey: key material in text (any length)
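//
// A minimal sketch of opening an encrypted database through the
// [database/sql] driver (assumes both the driver package and this
// package are imported; the file name and key are placeholders):
//
// db, err := sql.Open("sqlite3",
// "file:demo.db?vfs=adiantum&textkey=your-secret-key")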
//
// However, this makes your key easily accessible to other parts of
// your application (e.g. through [vfs.Filename.URIParameters]).
//
// To avoid this, invoke any of the following PRAGMAs
// immediately after opening a connection:
//
// PRAGMA key='D41d8cD98f00b204e9800998eCf8427e';
// PRAGMA hexkey='e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855';
// PRAGMA textkey='your-secret-key';
//
// For an ATTACH-ed database, you must specify the schema name:
//
// ATTACH DATABASE 'demo.db' AS demo;
// PRAGMA demo.textkey='your-secret-key';
//
// [URI]: https://sqlite.org/uri.html
package adiantum
import (
"lukechampine.com/adiantum/hbsh"
"github.com/ncruces/go-sqlite3/vfs"
)
func init() {
vfs.Register("adiantum", Wrap(vfs.Find(""), nil))
}
// Wrap wraps a base VFS to create an encrypting VFS,
// possibly using a custom HBSH cipher construction.
//
// To use the default Adiantum construction, set cipher to nil.
//
// The default construction uses a 32 byte key/hexkey.
// If a textkey is provided, the default KDF is Argon2id
// with 64 MiB of memory, 3 iterations, and 4 threads.
func Wrap(base vfs.VFS, cipher HBSHCreator) vfs.VFS {
if cipher == nil {
cipher = adiantumCreator{}
}
return &hbshVFS{
VFS: base,
init: cipher,
}
}
// HBSHCreator creates an [hbsh.HBSH] cipher
// given key material.
type HBSHCreator interface {
// KDF derives an HBSH key from a secret.
// If no secret is given, a random key is generated.
KDF(secret string) (key []byte)
// HBSH creates an HBSH cipher given a key.
// If key is not appropriate, nil is returned.
HBSH(key []byte) *hbsh.HBSH
}
package adiantum
import (
"encoding/binary"
"encoding/hex"
"io"
"lukechampine.com/adiantum/hbsh"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/vfsutil"
"github.com/ncruces/go-sqlite3/vfs"
)
type hbshVFS struct {
vfs.VFS
init HBSHCreator
}
func (h *hbshVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
// notest // OpenFilename is called instead
return nil, 0, sqlite3.CANTOPEN
}
func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
file, flags, err = vfsutil.WrapOpenFilename(h.VFS, name, flags)
// Encrypt everything except super journals and memory files.
if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 {
return file, flags, err
}
var hbsh *hbsh.HBSH
if f, ok := vfsutil.UnwrapFile[*hbshFile](name.DatabaseFile()); ok {
hbsh = f.hbsh
} else {
var key []byte
if params := name.URIParameters(); name == nil {
key = h.init.KDF("") // Temporary files get a random key.
} else if t, ok := params["key"]; ok {
key = []byte(t[0])
} else if t, ok := params["hexkey"]; ok {
key, _ = hex.DecodeString(t[0])
} else if t, ok := params["textkey"]; ok && len(t[0]) > 0 {
key = h.init.KDF(t[0])
} else if flags&vfs.OPEN_MAIN_DB != 0 {
// Main databases may have their key specified as a PRAGMA.
return &hbshFile{File: file, init: h.init}, flags, nil
}
hbsh = h.init.HBSH(key)
}
if hbsh == nil {
file.Close()
return nil, flags, sqlite3.CANTOPEN
}
return &hbshFile{File: file, hbsh: hbsh, init: h.init}, flags, nil
}
// Larger blocks improve both security (wide-block cipher)
// and throughput (cheap hashes amortize the block cipher's cost).
// Use the default SQLite page size;
// smaller pages pay the cost of unaligned access.
// https://sqlite.org/pgszchng2016.html
const (
tweakSize = 8
blockSize = 4096
)
// Ensure blockSize is a power of two.
var _ [0]struct{} = [blockSize & (blockSize - 1)]struct{}{}
func roundDown(i int64) int64 {
return i &^ (blockSize - 1)
}
func roundUp[T int | int64](i T) T {
return (i + (blockSize - 1)) &^ (blockSize - 1)
}
type hbshFile struct {
vfs.File
init HBSHCreator
hbsh *hbsh.HBSH
tweak [tweakSize]byte
block [blockSize]byte
}
func (h *hbshFile) Pragma(name string, value string) (string, error) {
var key []byte
switch name {
case "key":
key = []byte(value)
case "hexkey":
key, _ = hex.DecodeString(value)
case "textkey":
if len(value) > 0 {
key = h.init.KDF(value)
}
default:
return vfsutil.WrapPragma(h.File, name, value)
}
if h.hbsh = h.init.HBSH(key); h.hbsh != nil {
return "ok", nil
}
return "", sqlite3.CANTOPEN
}
func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) {
if h.hbsh == nil {
// Only OPEN_MAIN_DB can have a missing key.
if off == 0 && len(p) == 100 {
// SQLite is trying to read the header of a database file.
// Pretend the file is empty so the key may be specified as a PRAGMA.
return 0, io.EOF
}
return 0, sqlite3.CANTOPEN
}
min := roundDown(off)
max := roundUp(off + int64(len(p)))
// Read one block at a time.
for ; min < max; min += blockSize {
m, err := h.File.ReadAt(h.block[:], min)
if m != blockSize {
return n, err
}
binary.LittleEndian.PutUint64(h.tweak[:], uint64(min))
data := h.hbsh.Decrypt(h.block[:], h.tweak[:])
if off > min {
data = data[off-min:]
}
n += copy(p[n:], data)
}
if n != len(p) {
panic(util.AssertErr())
}
return n, nil
}
func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) {
if h.hbsh == nil {
return 0, sqlite3.READONLY
}
min := roundDown(off)
max := roundUp(off + int64(len(p)))
// Write one block at a time.
for ; min < max; min += blockSize {
binary.LittleEndian.PutUint64(h.tweak[:], uint64(min))
data := h.block[:]
if off > min || len(p[n:]) < blockSize {
// Partial block write: read-update-write.
m, err := h.File.ReadAt(h.block[:], min)
if m != blockSize {
if err != io.EOF {
return n, err
}
// Writing past the EOF.
// We're either appending an entirely new block,
// or the final block was only partially written.
// A partially written block can't be decrypted,
// and is as good as corrupt.
// Either way, zero pad the file to the next block size.
clear(data)
} else {
data = h.hbsh.Decrypt(h.block[:], h.tweak[:])
}
if off > min {
data = data[off-min:]
}
}
t := copy(data, p[n:])
h.hbsh.Encrypt(h.block[:], h.tweak[:])
m, err := h.File.WriteAt(h.block[:], min)
if m != blockSize {
return n, err
}
n += t
}
if n != len(p) {
panic(util.AssertErr())
}
return n, nil
}
func (h *hbshFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
var _ [0]struct{} = [blockSize - 4096]struct{}{} // Ensure blockSize is 4K.
return h.File.DeviceCharacteristics() & (0 |
// These flags are safe:
vfs.IOCAP_ATOMIC4K |
vfs.IOCAP_IMMUTABLE |
vfs.IOCAP_SEQUENTIAL |
vfs.IOCAP_SUBPAGE_READ |
vfs.IOCAP_BATCH_ATOMIC |
vfs.IOCAP_UNDELETABLE_WHEN_OPEN)
}
func (h *hbshFile) SectorSize() int {
return util.LCM(h.File.SectorSize(), blockSize)
}
func (h *hbshFile) Truncate(size int64) error {
return h.File.Truncate(roundUp(size))
}
func (h *hbshFile) ChunkSize(size int) {
vfsutil.WrapChunkSize(h.File, roundUp(size))
}
func (h *hbshFile) SizeHint(size int64) error {
return vfsutil.WrapSizeHint(h.File, roundUp(size))
}
func (h *hbshFile) Unwrap() vfs.File {
return h.File
}
func (h *hbshFile) SharedMemory() vfs.SharedMemory {
return vfsutil.WrapSharedMemory(h.File)
}
// Wrap optional methods.
func (h *hbshFile) LockState() vfs.LockLevel {
return vfsutil.WrapLockState(h.File) // notest
}
func (h *hbshFile) PersistentWAL() bool {
return vfsutil.WrapPersistWAL(h.File) // notest
}
func (h *hbshFile) SetPersistentWAL(keepWAL bool) {
vfsutil.WrapSetPersistWAL(h.File, keepWAL) // notest
}
func (h *hbshFile) HasMoved() (bool, error) {
return vfsutil.WrapHasMoved(h.File) // notest
}
func (h *hbshFile) Overwrite() error {
return vfsutil.WrapOverwrite(h.File) // notest
}
func (h *hbshFile) SyncSuper(super string) error {
return vfsutil.WrapSyncSuper(h.File, super) // notest
}
func (h *hbshFile) CommitPhaseTwo() error {
return vfsutil.WrapCommitPhaseTwo(h.File) // notest
}
func (h *hbshFile) BeginAtomicWrite() error {
return vfsutil.WrapBeginAtomicWrite(h.File) // notest
}
func (h *hbshFile) CommitAtomicWrite() error {
return vfsutil.WrapCommitAtomicWrite(h.File) // notest
}
func (h *hbshFile) RollbackAtomicWrite() error {
return vfsutil.WrapRollbackAtomicWrite(h.File) // notest
}
func (h *hbshFile) CheckpointStart() {
vfsutil.WrapCheckpointStart(h.File) // notest
}
func (h *hbshFile) CheckpointDone() {
vfsutil.WrapCheckpointDone(h.File) // notest
}
func (h *hbshFile) BusyHandler(handler func() bool) {
vfsutil.WrapBusyHandler(h.File, handler) // notest
}
package vfs
import (
"bytes"
"context"
_ "embed"
"encoding/binary"
"strconv"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/sql3util"
)
func cksmWrapFile(name *Filename, flags OpenFlag, file File) File {
// Checksum only main databases and WALs.
if flags&(OPEN_MAIN_DB|OPEN_WAL) == 0 {
return file
}
cksm := cksmFile{File: file}
if flags&OPEN_WAL != 0 {
main, _ := name.DatabaseFile().(cksmFile)
cksm.cksmFlags = main.cksmFlags
} else {
cksm.cksmFlags = new(cksmFlags)
cksm.isDB = true
}
return cksm
}
type cksmFile struct {
File
*cksmFlags
isDB bool
}
type cksmFlags struct {
computeCksm bool
verifyCksm bool
inCkpt bool
pageSize int
}
func (c cksmFile) ReadAt(p []byte, off int64) (n int, err error) {
n, err = c.File.ReadAt(p, off)
p = p[:n]
if isHeader(c.isDB, p, off) {
c.init((*[100]byte)(p))
}
// Verify checksums.
if c.verifyCksm && !c.inCkpt && len(p) == c.pageSize {
cksm1 := cksmCompute(p[:len(p)-8])
cksm2 := *(*[8]byte)(p[len(p)-8:])
if cksm1 != cksm2 {
return 0, _IOERR_DATA
}
}
return n, err
}
func (c cksmFile) WriteAt(p []byte, off int64) (n int, err error) {
if isHeader(c.isDB, p, off) {
c.init((*[100]byte)(p))
}
// Compute checksums.
if c.computeCksm && !c.inCkpt && len(p) == c.pageSize {
*(*[8]byte)(p[len(p)-8:]) = cksmCompute(p[:len(p)-8])
}
return c.File.WriteAt(p, off)
}
func (c cksmFile) Pragma(name string, value string) (string, error) {
switch name {
case "checksum_verification":
b, ok := sql3util.ParseBool(value)
if ok {
c.verifyCksm = b && c.computeCksm
}
if !c.verifyCksm {
return "0", nil
}
return "1", nil
case "page_size":
if c.computeCksm {
// Do not allow page size changes on a checksum database.
return strconv.Itoa(c.pageSize), nil
}
}
return "", _NOTFOUND
}
func (c cksmFile) DeviceCharacteristics() DeviceCharacteristic {
ret := c.File.DeviceCharacteristics()
if c.verifyCksm {
ret &^= IOCAP_SUBPAGE_READ
}
return ret
}
func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg ptr_t) _ErrorCode {
switch op {
case _FCNTL_CKPT_START:
c.inCkpt = true
case _FCNTL_CKPT_DONE:
c.inCkpt = false
case _FCNTL_PRAGMA:
rc := vfsFileControlImpl(ctx, mod, c, op, pArg)
if rc != _NOTFOUND {
return rc
}
}
return vfsFileControlImpl(ctx, mod, c.File, op, pArg)
}
func (f *cksmFlags) init(header *[100]byte) {
// The header stores the page size big-endian at offset 16; reading it
// little-endian and multiplying by 256 also maps the special value 1
// (which encodes a page size of 65536) to 65536.
f.pageSize = 256 * int(binary.LittleEndian.Uint16(header[16:18]))
// Offset 20 holds the bytes of reserved space per page;
// checksums require exactly 8 reserved bytes at the end of each page.
if r := header[20] == 8; r != f.computeCksm {
f.computeCksm = r
f.verifyCksm = r
}
if !sql3util.ValidPageSize(f.pageSize) {
f.computeCksm = false
f.verifyCksm = false
}
}
func isHeader(isDB bool, p []byte, off int64) bool {
check := sql3util.ValidPageSize(len(p))
if isDB {
check = off == 0 && len(p) >= 100
}
return check && bytes.HasPrefix(p, []byte("SQLite format 3\000"))
}
func cksmCompute(a []byte) (cksm [8]byte) {
var s1, s2 uint32
for len(a) >= 8 {
s1 += binary.LittleEndian.Uint32(a[0:4]) + s2
s2 += binary.LittleEndian.Uint32(a[4:8]) + s1
a = a[8:]
}
if len(a) != 0 {
panic(util.AssertErr())
}
binary.LittleEndian.PutUint32(cksm[0:4], s1)
binary.LittleEndian.PutUint32(cksm[4:8], s2)
return
}
func (c cksmFile) SharedMemory() SharedMemory {
if f, ok := c.File.(FileSharedMemory); ok {
return f.SharedMemory()
}
return nil
}
func (c cksmFile) Unwrap() File {
return c.File
}
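// cksmVerifyPage is an illustrative sketch, not part of the package API:
// it shows the page layout the code above relies on, where the 8-byte
// checksum of a page is stored in that page's final 8 reserved bytes.
func cksmVerifyPage(page []byte) bool {
got := cksmCompute(page[:len(page)-8])
want := *(*[8]byte)(page[len(page)-8:])
return got == want
}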
package vfs
import "github.com/ncruces/go-sqlite3/internal/util"
const (
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
_MAX_SQL_LENGTH = 1e9
_MAX_PATHNAME = 1024
_DEFAULT_SECTOR_SIZE = 4096
ptrlen = util.PtrLen
)
type (
stk_t = util.Stk_t
ptr_t = util.Ptr_t
)
// https://sqlite.org/rescode.html
type _ErrorCode uint32
func (e _ErrorCode) Error() string {
return util.ErrorCodeString(uint32(e))
}
const (
_OK _ErrorCode = util.OK
_ERROR _ErrorCode = util.ERROR
_PERM _ErrorCode = util.PERM
_BUSY _ErrorCode = util.BUSY
_READONLY _ErrorCode = util.READONLY
_IOERR _ErrorCode = util.IOERR
_NOTFOUND _ErrorCode = util.NOTFOUND
_FULL _ErrorCode = util.FULL
_CANTOPEN _ErrorCode = util.CANTOPEN
_IOERR_READ _ErrorCode = util.IOERR_READ
_IOERR_SHORT_READ _ErrorCode = util.IOERR_SHORT_READ
_IOERR_WRITE _ErrorCode = util.IOERR_WRITE
_IOERR_FSYNC _ErrorCode = util.IOERR_FSYNC
_IOERR_DIR_FSYNC _ErrorCode = util.IOERR_DIR_FSYNC
_IOERR_TRUNCATE _ErrorCode = util.IOERR_TRUNCATE
_IOERR_FSTAT _ErrorCode = util.IOERR_FSTAT
_IOERR_UNLOCK _ErrorCode = util.IOERR_UNLOCK
_IOERR_RDLOCK _ErrorCode = util.IOERR_RDLOCK
_IOERR_DELETE _ErrorCode = util.IOERR_DELETE
_IOERR_ACCESS _ErrorCode = util.IOERR_ACCESS
_IOERR_CHECKRESERVEDLOCK _ErrorCode = util.IOERR_CHECKRESERVEDLOCK
_IOERR_LOCK _ErrorCode = util.IOERR_LOCK
_IOERR_CLOSE _ErrorCode = util.IOERR_CLOSE
_IOERR_SHMOPEN _ErrorCode = util.IOERR_SHMOPEN
_IOERR_SHMSIZE _ErrorCode = util.IOERR_SHMSIZE
_IOERR_SHMLOCK _ErrorCode = util.IOERR_SHMLOCK
_IOERR_SHMMAP _ErrorCode = util.IOERR_SHMMAP
_IOERR_SEEK _ErrorCode = util.IOERR_SEEK
_IOERR_DELETE_NOENT _ErrorCode = util.IOERR_DELETE_NOENT
_IOERR_GETTEMPPATH _ErrorCode = util.IOERR_GETTEMPPATH
_IOERR_BEGIN_ATOMIC _ErrorCode = util.IOERR_BEGIN_ATOMIC
_IOERR_COMMIT_ATOMIC _ErrorCode = util.IOERR_COMMIT_ATOMIC
_IOERR_ROLLBACK_ATOMIC _ErrorCode = util.IOERR_ROLLBACK_ATOMIC
_IOERR_DATA _ErrorCode = util.IOERR_DATA
_IOERR_CORRUPTFS _ErrorCode = util.IOERR_CORRUPTFS
_BUSY_SNAPSHOT _ErrorCode = util.BUSY_SNAPSHOT
_CANTOPEN_FULLPATH _ErrorCode = util.CANTOPEN_FULLPATH
_CANTOPEN_ISDIR _ErrorCode = util.CANTOPEN_ISDIR
_READONLY_CANTINIT _ErrorCode = util.READONLY_CANTINIT
_READONLY_DIRECTORY _ErrorCode = util.READONLY_DIRECTORY
_OK_SYMLINK _ErrorCode = util.OK_SYMLINK
)
// OpenFlag is a flag for the [VFS] Open method.
//
// https://sqlite.org/c3ref/c_open_autoproxy.html
type OpenFlag uint32
const (
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
OPEN_DELETEONCLOSE OpenFlag = 0x00000008 /* VFS only */
OPEN_EXCLUSIVE OpenFlag = 0x00000010 /* VFS only */
OPEN_AUTOPROXY OpenFlag = 0x00000020 /* VFS only */
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
OPEN_MAIN_DB OpenFlag = 0x00000100 /* VFS only */
OPEN_TEMP_DB OpenFlag = 0x00000200 /* VFS only */
OPEN_TRANSIENT_DB OpenFlag = 0x00000400 /* VFS only */
OPEN_MAIN_JOURNAL OpenFlag = 0x00000800 /* VFS only */
OPEN_TEMP_JOURNAL OpenFlag = 0x00001000 /* VFS only */
OPEN_SUBJOURNAL OpenFlag = 0x00002000 /* VFS only */
OPEN_SUPER_JOURNAL OpenFlag = 0x00004000 /* VFS only */
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
OPEN_WAL OpenFlag = 0x00080000 /* VFS only */
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
)
// AccessFlag is a flag for the [VFS] Access method.
//
// https://sqlite.org/c3ref/c_access_exists.html
type AccessFlag uint32
const (
ACCESS_EXISTS AccessFlag = 0
ACCESS_READWRITE AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
ACCESS_READ AccessFlag = 2 /* Unused */
)
// SyncFlag is a flag for the [File] Sync method.
//
// https://sqlite.org/c3ref/c_sync_dataonly.html
type SyncFlag uint32
const (
SYNC_NORMAL SyncFlag = 0x00002
SYNC_FULL SyncFlag = 0x00003
SYNC_DATAONLY SyncFlag = 0x00010
)
// LockLevel is a value used with [File] Lock and Unlock methods.
//
// https://sqlite.org/c3ref/c_lock_exclusive.html
type LockLevel uint32
const (
// No locks are held on the database.
// The database may be neither read nor written.
// Any internally cached data is considered suspect and subject to
// verification against the database file before being used.
// Other processes can read or write the database as their own locking
// states permit.
// This is the default state.
LOCK_NONE LockLevel = 0 /* xUnlock() only */
// The database may be read but not written.
// Any number of processes can hold SHARED locks at the same time,
// hence there can be many simultaneous readers.
// But no other thread or process is allowed to write to the database file
// while one or more SHARED locks are active.
LOCK_SHARED LockLevel = 1 /* xLock() or xUnlock() */
// A RESERVED lock means that the process is planning on writing to the
// database file at some point in the future but that it is currently just
// reading from the file.
// Only a single RESERVED lock may be active at one time,
// though multiple SHARED locks can coexist with a single RESERVED lock.
// RESERVED differs from PENDING in that new SHARED locks can be acquired
// while there is a RESERVED lock.
LOCK_RESERVED LockLevel = 2 /* xLock() only */
// A PENDING lock means that the process holding the lock wants to write to
// the database as soon as possible and is just waiting on all current
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
// No new SHARED locks are permitted against the database if a PENDING lock
// is active, though existing SHARED locks are allowed to continue.
LOCK_PENDING LockLevel = 3 /* internal use only */
// An EXCLUSIVE lock is needed in order to write to the database file.
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
// kind are allowed to coexist with an EXCLUSIVE lock.
// In order to maximize concurrency, SQLite works to minimize the amount of
// time that EXCLUSIVE locks are held.
LOCK_EXCLUSIVE LockLevel = 4 /* xLock() only */
)
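// Locks only ever move through the progression
// NONE -> SHARED -> RESERVED -> PENDING -> EXCLUSIVE, and back down to
// SHARED or NONE. A hedged illustration, assuming f is a File returned
// by a VFS Open call:
//
// f.Lock(LOCK_SHARED)    // start reading
// f.Lock(LOCK_RESERVED)  // plan to write
// f.Lock(LOCK_EXCLUSIVE) // write (PENDING is taken internally)
// f.Unlock(LOCK_SHARED)  // commit, keep reading
// f.Unlock(LOCK_NONE)    // done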
// DeviceCharacteristic is a flag returned by the [File] DeviceCharacteristics method.
//
// https://sqlite.org/c3ref/c_iocap_atomic.html
type DeviceCharacteristic uint32
const (
IOCAP_ATOMIC DeviceCharacteristic = 0x00000001
IOCAP_ATOMIC512 DeviceCharacteristic = 0x00000002
IOCAP_ATOMIC1K DeviceCharacteristic = 0x00000004
IOCAP_ATOMIC2K DeviceCharacteristic = 0x00000008
IOCAP_ATOMIC4K DeviceCharacteristic = 0x00000010
IOCAP_ATOMIC8K DeviceCharacteristic = 0x00000020
IOCAP_ATOMIC16K DeviceCharacteristic = 0x00000040
IOCAP_ATOMIC32K DeviceCharacteristic = 0x00000080
IOCAP_ATOMIC64K DeviceCharacteristic = 0x00000100
IOCAP_SAFE_APPEND DeviceCharacteristic = 0x00000200
IOCAP_SEQUENTIAL DeviceCharacteristic = 0x00000400
IOCAP_UNDELETABLE_WHEN_OPEN DeviceCharacteristic = 0x00000800
IOCAP_POWERSAFE_OVERWRITE DeviceCharacteristic = 0x00001000
IOCAP_IMMUTABLE DeviceCharacteristic = 0x00002000
IOCAP_BATCH_ATOMIC DeviceCharacteristic = 0x00004000
IOCAP_SUBPAGE_READ DeviceCharacteristic = 0x00008000
)
// https://sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type _FcntlOpcode uint32
const (
_FCNTL_LOCKSTATE _FcntlOpcode = 1
_FCNTL_GET_LOCKPROXYFILE _FcntlOpcode = 2
_FCNTL_SET_LOCKPROXYFILE _FcntlOpcode = 3
_FCNTL_LAST_ERRNO _FcntlOpcode = 4
_FCNTL_SIZE_HINT _FcntlOpcode = 5
_FCNTL_CHUNK_SIZE _FcntlOpcode = 6
_FCNTL_FILE_POINTER _FcntlOpcode = 7
_FCNTL_SYNC_OMITTED _FcntlOpcode = 8
_FCNTL_WIN32_AV_RETRY _FcntlOpcode = 9
_FCNTL_PERSIST_WAL _FcntlOpcode = 10
_FCNTL_OVERWRITE _FcntlOpcode = 11
_FCNTL_VFSNAME _FcntlOpcode = 12
_FCNTL_POWERSAFE_OVERWRITE _FcntlOpcode = 13
_FCNTL_PRAGMA _FcntlOpcode = 14
_FCNTL_BUSYHANDLER _FcntlOpcode = 15
_FCNTL_TEMPFILENAME _FcntlOpcode = 16
_FCNTL_MMAP_SIZE _FcntlOpcode = 18
_FCNTL_TRACE _FcntlOpcode = 19
_FCNTL_HAS_MOVED _FcntlOpcode = 20
_FCNTL_SYNC _FcntlOpcode = 21
_FCNTL_COMMIT_PHASETWO _FcntlOpcode = 22
_FCNTL_WIN32_SET_HANDLE _FcntlOpcode = 23
_FCNTL_WAL_BLOCK _FcntlOpcode = 24
_FCNTL_ZIPVFS _FcntlOpcode = 25
_FCNTL_RBU _FcntlOpcode = 26
_FCNTL_VFS_POINTER _FcntlOpcode = 27
_FCNTL_JOURNAL_POINTER _FcntlOpcode = 28
_FCNTL_WIN32_GET_HANDLE _FcntlOpcode = 29
_FCNTL_PDB _FcntlOpcode = 30
_FCNTL_BEGIN_ATOMIC_WRITE _FcntlOpcode = 31
_FCNTL_COMMIT_ATOMIC_WRITE _FcntlOpcode = 32
_FCNTL_ROLLBACK_ATOMIC_WRITE _FcntlOpcode = 33
_FCNTL_LOCK_TIMEOUT _FcntlOpcode = 34
_FCNTL_DATA_VERSION _FcntlOpcode = 35
_FCNTL_SIZE_LIMIT _FcntlOpcode = 36
_FCNTL_CKPT_DONE _FcntlOpcode = 37
_FCNTL_RESERVE_BYTES _FcntlOpcode = 38
_FCNTL_CKPT_START _FcntlOpcode = 39
_FCNTL_EXTERNAL_READER _FcntlOpcode = 40
_FCNTL_CKSM_FILE _FcntlOpcode = 41
_FCNTL_RESET_CACHE _FcntlOpcode = 42
_FCNTL_NULL_IO _FcntlOpcode = 43
)
// https://sqlite.org/c3ref/c_shm_exclusive.html
type _ShmFlag uint32
const (
_SHM_UNLOCK _ShmFlag = 1
_SHM_LOCK _ShmFlag = 2
_SHM_SHARED _ShmFlag = 4
_SHM_EXCLUSIVE _ShmFlag = 8
_SHM_NLOCK = 8                      // number of WAL-index lock slots
_SHM_BASE  = 120                    // offset of the first lock byte in the shared-memory file
_SHM_DMS   = _SHM_BASE + _SHM_NLOCK // the "dead man switch" byte
)
package vfs
import (
"errors"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"syscall"
)
type vfsOS struct{}
func (vfsOS) FullPathname(path string) (string, error) {
path, err := filepath.Abs(path)
if err != nil {
return "", err
}
return path, testSymlinks(filepath.Dir(path))
}
func testSymlinks(path string) error {
p, err := filepath.EvalSymlinks(path)
if err != nil {
return err
}
if p != path {
return _OK_SYMLINK
}
return nil
}
func (vfsOS) Delete(path string, syncDir bool) error {
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _IOERR_DELETE_NOENT
}
if err != nil {
return err
}
if isUnix && syncDir {
f, err := os.Open(filepath.Dir(path))
if err != nil {
return _OK
}
defer f.Close()
err = osSync(f, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return nil
}
func (vfsOS) Access(name string, flags AccessFlag) (bool, error) {
err := osAccess(name, flags)
if flags == ACCESS_EXISTS {
if errors.Is(err, fs.ErrNotExist) {
return false, nil
}
} else {
if errors.Is(err, fs.ErrPermission) {
return false, nil
}
}
return err == nil, err
}
func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
// notest // OpenFilename is called instead
return nil, 0, _CANTOPEN
}
func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error) {
oflags := _O_NOFOLLOW
if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
}
if flags&OPEN_CREATE != 0 {
oflags |= os.O_CREATE
}
if flags&OPEN_READONLY != 0 {
oflags |= os.O_RDONLY
}
if flags&OPEN_READWRITE != 0 {
oflags |= os.O_RDWR
}
isCreate := flags&(OPEN_CREATE) != 0
isJournl := flags&(OPEN_MAIN_JOURNAL|OPEN_SUPER_JOURNAL|OPEN_WAL) != 0
var err error
var f *os.File
if name == nil {
f, err = os.CreateTemp(os.Getenv("SQLITE_TMPDIR"), "*.db")
} else {
f, err = os.OpenFile(name.String(), oflags, 0666)
}
if err != nil {
if name == nil {
return nil, flags, _IOERR_GETTEMPPATH
}
if errors.Is(err, syscall.EISDIR) {
return nil, flags, _CANTOPEN_ISDIR
}
if isCreate && isJournl && errors.Is(err, fs.ErrPermission) &&
osAccess(name.String(), ACCESS_EXISTS) != nil {
return nil, flags, _READONLY_DIRECTORY
}
return nil, flags, err
}
if modeof := name.URIParameter("modeof"); modeof != "" {
if err = osSetMode(f, modeof); err != nil {
f.Close()
return nil, flags, _IOERR_FSTAT
}
}
if isUnix && flags&OPEN_DELETEONCLOSE != 0 {
os.Remove(f.Name())
}
file := vfsFile{
File: f,
psow: true,
atomic: osBatchAtomic(f),
readOnly: flags&OPEN_READONLY != 0,
syncDir: isUnix && isCreate && isJournl,
delete: !isUnix && flags&OPEN_DELETEONCLOSE != 0,
shm: NewSharedMemory(name.String()+"-shm", flags),
}
return &file, flags, nil
}
type vfsFile struct {
*os.File
shm SharedMemory
lock LockLevel
readOnly bool
keepWAL bool
syncDir bool
atomic bool
delete bool
psow bool
}
var (
// Ensure these interfaces are implemented:
_ FileLockState = &vfsFile{}
_ FileHasMoved = &vfsFile{}
_ FileSizeHint = &vfsFile{}
_ FilePersistWAL = &vfsFile{}
_ FilePowersafeOverwrite = &vfsFile{}
)
func (f *vfsFile) Close() error {
if f.delete {
defer os.Remove(f.Name())
}
if f.shm != nil {
f.shm.Close()
}
f.Unlock(LOCK_NONE)
return f.File.Close()
}
func (f *vfsFile) ReadAt(p []byte, off int64) (n int, err error) {
return osReadAt(f.File, p, off)
}
func (f *vfsFile) WriteAt(p []byte, off int64) (n int, err error) {
return osWriteAt(f.File, p, off)
}
func (f *vfsFile) Sync(flags SyncFlag) error {
dataonly := (flags & SYNC_DATAONLY) != 0
fullsync := (flags & 0x0f) == SYNC_FULL
err := osSync(f.File, fullsync, dataonly)
if err != nil {
return err
}
if isUnix && f.syncDir {
f.syncDir = false
d, err := os.Open(filepath.Dir(f.File.Name()))
if err != nil {
return nil
}
defer d.Close()
err = osSync(d, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return nil
}
func (f *vfsFile) Size() (int64, error) {
return f.Seek(0, io.SeekEnd)
}
func (f *vfsFile) SectorSize() int {
return _DEFAULT_SECTOR_SIZE
}
func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic {
ret := IOCAP_SUBPAGE_READ
if f.atomic {
ret |= IOCAP_BATCH_ATOMIC
}
if f.psow {
ret |= IOCAP_POWERSAFE_OVERWRITE
}
if runtime.GOOS == "windows" {
ret |= IOCAP_UNDELETABLE_WHEN_OPEN
}
return ret
}
func (f *vfsFile) SizeHint(size int64) error {
return osAllocate(f.File, size)
}
func (f *vfsFile) HasMoved() (bool, error) {
if runtime.GOOS == "windows" {
return false, nil
}
fi, err := f.Stat()
if err != nil {
return false, err
}
pi, err := os.Stat(f.Name())
if errors.Is(err, fs.ErrNotExist) {
return true, nil
}
if err != nil {
return false, err
}
return !os.SameFile(fi, pi), nil
}
func (f *vfsFile) LockState() LockLevel { return f.lock }
func (f *vfsFile) PowersafeOverwrite() bool { return f.psow }
func (f *vfsFile) PersistWAL() bool { return f.keepWAL }
func (f *vfsFile) SetPowersafeOverwrite(psow bool) { f.psow = psow }
func (f *vfsFile) SetPersistWAL(keepWAL bool) { f.keepWAL = keepWAL }
package vfs
import (
"context"
"net/url"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Filename is used by SQLite to pass filenames
// to the Open method of a VFS.
//
// https://sqlite.org/c3ref/filename.html
type Filename struct {
ctx context.Context
mod api.Module
zPath ptr_t
flags OpenFlag
stack [2]stk_t
}
// GetFilename is an internal API users should not call directly.
func GetFilename(ctx context.Context, mod api.Module, id ptr_t, flags OpenFlag) *Filename {
if id == 0 {
return nil
}
return &Filename{
ctx: ctx,
mod: mod,
zPath: id,
flags: flags,
}
}
// String returns this filename as a string.
func (n *Filename) String() string {
if n == nil || n.zPath == 0 {
return ""
}
return util.ReadString(n.mod, n.zPath, _MAX_PATHNAME)
}
// Database returns the name of the corresponding database file.
//
// https://sqlite.org/c3ref/filename_database.html
func (n *Filename) Database() string {
return n.path("sqlite3_filename_database")
}
// Journal returns the name of the corresponding rollback journal file.
//
// https://sqlite.org/c3ref/filename_database.html
func (n *Filename) Journal() string {
return n.path("sqlite3_filename_journal")
}
// WAL returns the name of the corresponding WAL file.
//
// https://sqlite.org/c3ref/filename_database.html
func (n *Filename) WAL() string {
return n.path("sqlite3_filename_wal")
}
func (n *Filename) path(method string) string {
if n == nil || n.zPath == 0 {
return ""
}
if n.flags&(OPEN_MAIN_DB|OPEN_MAIN_JOURNAL|OPEN_WAL) == 0 {
return ""
}
n.stack[0] = stk_t(n.zPath)
fn := n.mod.ExportedFunction(method)
if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
return util.ReadString(n.mod, ptr_t(n.stack[0]), _MAX_PATHNAME)
}
// DatabaseFile returns the main database [File] corresponding to a journal.
//
// https://sqlite.org/c3ref/database_file_object.html
func (n *Filename) DatabaseFile() File {
if n == nil || n.zPath == 0 {
return nil
}
if n.flags&(OPEN_MAIN_DB|OPEN_MAIN_JOURNAL|OPEN_WAL) == 0 {
return nil
}
n.stack[0] = stk_t(n.zPath)
fn := n.mod.ExportedFunction("sqlite3_database_file_object")
if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
file, _ := vfsFileGet(n.ctx, n.mod, ptr_t(n.stack[0])).(File)
return file
}
// URIParameter returns the value of a URI parameter.
//
// https://sqlite.org/c3ref/uri_boolean.html
func (n *Filename) URIParameter(key string) string {
if n == nil || n.zPath == 0 {
return ""
}
uriKey := n.mod.ExportedFunction("sqlite3_uri_key")
n.stack[0] = stk_t(n.zPath)
n.stack[1] = stk_t(0)
if err := uriKey.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
ptr := ptr_t(n.stack[0])
if ptr == 0 {
return ""
}
// Parse the format from:
// https://github.com/sqlite/sqlite/blob/b74eb0/src/pager.c#L4797-L4840
// This avoids having to alloc/free the key just to find a value.
for {
k := util.ReadString(n.mod, ptr, _MAX_NAME)
if k == "" {
return ""
}
ptr += ptr_t(len(k)) + 1
v := util.ReadString(n.mod, ptr, _MAX_NAME)
if k == key {
return v
}
ptr += ptr_t(len(v)) + 1
}
}
// URIParameters obtains values for URI parameters.
//
// https://sqlite.org/c3ref/uri_boolean.html
func (n *Filename) URIParameters() url.Values {
if n == nil || n.zPath == 0 {
return nil
}
uriKey := n.mod.ExportedFunction("sqlite3_uri_key")
n.stack[0] = stk_t(n.zPath)
n.stack[1] = stk_t(0)
if err := uriKey.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
ptr := ptr_t(n.stack[0])
if ptr == 0 {
return nil
}
var params url.Values
// Parse the format from:
// https://github.com/sqlite/sqlite/blob/b74eb0/src/pager.c#L4797-L4840
// This is the only way to support multi-valued keys.
for {
k := util.ReadString(n.mod, ptr, _MAX_NAME)
if k == "" {
return params
}
ptr += ptr_t(len(k)) + 1
v := util.ReadString(n.mod, ptr, _MAX_NAME)
if params == nil {
params = url.Values{}
}
params.Add(k, v)
ptr += ptr_t(len(v)) + 1
}
}
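// A hedged example of the two methods above: for a database opened as
// "file:demo.db?vfs=memdb&cache=shared", URIParameter("vfs") returns
// "memdb", while URIParameters returns
// url.Values{"vfs": {"memdb"}, "cache": {"shared"}}.
// A key missing from the URI yields "" from URIParameter,
// and a URI with no parameters yields nil from URIParameters.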
//go:build linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos || sqlite3_flock || sqlite3_dotlk
package vfs
import "github.com/ncruces/go-sqlite3/internal/util"
// SupportsFileLocking is false on platforms that do not support file locking.
// To open a database file on those platforms,
// you need to use the [nolock] or [immutable] URI parameters.
//
// [nolock]: https://sqlite.org/uri.html#urinolock
// [immutable]: https://sqlite.org/uri.html#uriimmutable
const SupportsFileLocking = true
const (
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
func (f *vfsFile) Lock(lock LockLevel) error {
switch {
case lock != LOCK_SHARED && lock != LOCK_RESERVED && lock != LOCK_EXCLUSIVE:
// Argument check. SQLite never explicitly requests a pending lock.
panic(util.AssertErr())
case f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE:
// Connection state check.
panic(util.AssertErr())
case f.lock == LOCK_NONE && lock > LOCK_SHARED:
// We never move from unlocked to anything higher than a shared lock.
panic(util.AssertErr())
case f.lock != LOCK_SHARED && lock == LOCK_RESERVED:
// A shared lock is always held when a reserved lock is requested.
panic(util.AssertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if f.lock >= lock {
return nil
}
// Do not allow any kind of write-lock on a read-only database.
if f.readOnly && lock >= LOCK_RESERVED {
return _IOERR_LOCK
}
switch lock {
case LOCK_SHARED:
// Must be unlocked to get SHARED.
if f.lock != LOCK_NONE {
panic(util.AssertErr())
}
if rc := osGetSharedLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_SHARED
return nil
case LOCK_RESERVED:
// Must be SHARED to get RESERVED.
if f.lock != LOCK_SHARED {
panic(util.AssertErr())
}
if rc := osGetReservedLock(f.File); rc != _OK {
return rc
}
f.lock = LOCK_RESERVED
return nil
case LOCK_EXCLUSIVE:
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if f.lock <= LOCK_NONE || f.lock >= LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
if rc := osGetExclusiveLock(f.File, &f.lock); rc != _OK {
return rc
}
f.lock = LOCK_EXCLUSIVE
return nil
default:
panic(util.AssertErr())
}
}
func (f *vfsFile) Unlock(lock LockLevel) error {
switch {
case lock != LOCK_NONE && lock != LOCK_SHARED:
// Argument check.
panic(util.AssertErr())
case f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE:
// Connection state check.
panic(util.AssertErr())
}
// If we don't have a more restrictive lock, do nothing.
if f.lock <= lock {
return nil
}
switch lock {
case LOCK_SHARED:
rc := osDowngradeLock(f.File, f.lock)
f.lock = LOCK_SHARED
return rc
case LOCK_NONE:
rc := osReleaseLock(f.File, f.lock)
f.lock = LOCK_NONE
return rc
default:
panic(util.AssertErr())
}
}
func (f *vfsFile) CheckReservedLock() (bool, error) {
// Connection state check.
if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
if f.lock >= LOCK_RESERVED {
return true, nil
}
return osCheckReservedLock(f.File)
}
// Package memdb implements the "memdb" SQLite VFS.
//
// The "memdb" [vfs.VFS] allows the same in-memory database to be shared
// among multiple database connections in the same process,
// as long as the database name begins with "/".
//
// Importing package memdb registers the VFS:
//
// import _ "github.com/ncruces/go-sqlite3/vfs/memdb"
package memdb
import (
"fmt"
"net/url"
"sync"
"testing"
"github.com/ncruces/go-sqlite3/vfs"
)
func init() {
vfs.Register("memdb", memVFS{})
}
var (
memoryMtx sync.Mutex
// +checklocks:memoryMtx
memoryDBs = map[string]*memDB{}
)
// Create creates a shared memory database,
// using data as its initial contents.
// The new database takes ownership of data,
// and the caller should not use data after this call.
func Create(name string, data []byte) {
memoryMtx.Lock()
defer memoryMtx.Unlock()
db := &memDB{
refs: 1,
name: name,
size: int64(len(data)),
}
// Convert data from WAL/2 to rollback journal.
if len(data) >= 20 && (false ||
data[18] == 2 && data[19] == 2 ||
data[18] == 3 && data[19] == 3) {
data[18] = 1
data[19] = 1
}
sectors := divRoundUp(db.size, sectorSize)
db.data = make([]*[sectorSize]byte, sectors)
for i := range db.data {
sector := data[i*sectorSize:]
if len(sector) >= sectorSize {
db.data[i] = (*[sectorSize]byte)(sector)
} else {
db.data[i] = new([sectorSize]byte)
copy((*db.data[i])[:], sector)
}
}
memoryDBs[name] = db
}
// Delete deletes a shared memory database.
func Delete(name string) {
memoryMtx.Lock()
defer memoryMtx.Unlock()
delete(memoryDBs, name)
}
// TestDB creates an empty shared memory database for the test to use.
// The database is automatically deleted when the test and all its subtests complete.
// Each subsequent call to TestDB returns a unique database.
func TestDB(tb testing.TB, params ...url.Values) string {
tb.Helper()
name := fmt.Sprintf("%s_%p", tb.Name(), tb)
tb.Cleanup(func() { Delete(name) })
Create(name, nil)
p := url.Values{"vfs": {"memdb"}}
for _, v := range params {
for k, v := range v {
for _, v := range v {
p.Add(k, v)
}
}
}
return (&url.URL{
Scheme: "file",
OmitHost: true,
Path: "/" + name,
RawQuery: p.Encode(),
}).String()
}
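// A hedged usage sketch (the database/sql driver and its "sqlite3" name
// are assumptions about github.com/ncruces/go-sqlite3/driver, which is not
// shown here): a database created with Create is shared by any connection
// in the process that opens it by name through the "memdb" VFS:
//
// memdb.Create("demo", nil)
// db, err := sql.Open("sqlite3", "file:/demo?vfs=memdb")
// ...
// memdb.Delete("demo")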
package memdb
import (
"io"
"runtime"
"sync"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
const sectorSize = 65536
// Ensure sectorSize is a multiple of 64K (the largest page size).
var _ [0]struct{} = [sectorSize & 65535]struct{}{}
type memVFS struct{}
func (memVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
// For simplicity, we do not support reading or writing data
// across "sector" boundaries.
//
// This is not a problem for most SQLite file types:
// - databases, which only do page aligned reads/writes;
// - temp journals, as used by the sorter, which does the same:
// https://github.com/sqlite/sqlite/blob/b74eb0/src/vdbesort.c#L409-L412
//
// We refuse to open all other file types,
// but returning OPEN_MEMORY means SQLite won't ask us to.
const types = vfs.OPEN_MAIN_DB |
vfs.OPEN_TEMP_DB |
vfs.OPEN_TEMP_JOURNAL
if flags&types == 0 {
// notest // OPEN_MEMORY
return nil, flags, sqlite3.CANTOPEN
}
// A shared database has a name that begins with "/".
shared := len(name) > 1 && name[0] == '/'
var db *memDB
if shared {
name = name[1:]
memoryMtx.Lock()
defer memoryMtx.Unlock()
db = memoryDBs[name]
}
if db == nil {
if flags&vfs.OPEN_CREATE == 0 {
return nil, flags, sqlite3.CANTOPEN
}
db = &memDB{name: name}
}
if shared {
db.refs++ // +checklocksforce: memoryMtx is held
memoryDBs[name] = db
}
return &memFile{
memDB: db,
readOnly: flags&vfs.OPEN_READONLY != 0,
}, flags | vfs.OPEN_MEMORY, nil
}
func (memVFS) Delete(name string, dirSync bool) error {
return sqlite3.IOERR_DELETE_NOENT // used to delete journals
}
func (memVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
return false, nil // used to check for journals
}
func (memVFS) FullPathname(name string) (string, error) {
return name, nil
}
type memDB struct {
name string
// +checklocks:dataMtx
data []*[sectorSize]byte
// +checklocks:dataMtx
size int64
// +checklocks:memoryMtx
refs int32
shared int32 // +checklocks:lockMtx
pending bool // +checklocks:lockMtx
reserved bool // +checklocks:lockMtx
lockMtx sync.Mutex
dataMtx sync.RWMutex
}
func (m *memDB) release() {
memoryMtx.Lock()
defer memoryMtx.Unlock()
if m.refs--; m.refs == 0 && m == memoryDBs[m.name] {
delete(memoryDBs, m.name)
}
}
type memFile struct {
*memDB
lock vfs.LockLevel
readOnly bool
}
var (
// Ensure these interfaces are implemented:
_ vfs.FileLockState = &memFile{}
_ vfs.FileSizeHint = &memFile{}
)
func (m *memFile) Close() error {
m.release()
return m.Unlock(vfs.LOCK_NONE)
}
func (m *memFile) ReadAt(b []byte, off int64) (n int, err error) {
m.dataMtx.RLock()
defer m.dataMtx.RUnlock()
if off >= m.size {
return 0, io.EOF
}
base := off / sectorSize
rest := off % sectorSize
have := int64(sectorSize)
if base == int64(len(m.data))-1 {
have = modRoundUp(m.size, sectorSize)
}
n = copy(b, (*m.data[base])[rest:have])
if n < len(b) {
// notest // assume reads are page aligned
return 0, io.ErrNoProgress
}
return n, nil
}
func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
m.dataMtx.Lock()
defer m.dataMtx.Unlock()
base := off / sectorSize
rest := off % sectorSize
for base >= int64(len(m.data)) {
m.data = append(m.data, new([sectorSize]byte))
}
n = copy((*m.data[base])[rest:], b)
if n < len(b) {
// notest // assume writes are page aligned
return n, io.ErrShortWrite
}
if size := off + int64(len(b)); size > m.size {
m.size = size
}
return n, nil
}
func (m *memFile) Truncate(size int64) error {
m.dataMtx.Lock()
defer m.dataMtx.Unlock()
return m.truncate(size)
}
// +checklocks:m.dataMtx
func (m *memFile) truncate(size int64) error {
if size < m.size {
base := size / sectorSize
rest := size % sectorSize
if rest != 0 {
clear((*m.data[base])[rest:])
}
}
sectors := divRoundUp(size, sectorSize)
for sectors > int64(len(m.data)) {
m.data = append(m.data, new([sectorSize]byte))
}
clear(m.data[sectors:])
m.data = m.data[:sectors]
m.size = size
return nil
}
func (m *memFile) Sync(flag vfs.SyncFlag) error {
return nil
}
func (m *memFile) Size() (int64, error) {
m.dataMtx.RLock()
defer m.dataMtx.RUnlock()
return m.size, nil
}
const spinWait = 25 * time.Microsecond
func (m *memFile) Lock(lock vfs.LockLevel) error {
if m.lock >= lock {
return nil
}
if m.readOnly && lock >= vfs.LOCK_RESERVED {
return sqlite3.IOERR_LOCK
}
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
switch lock {
case vfs.LOCK_SHARED:
if m.pending {
return sqlite3.BUSY
}
m.shared++
case vfs.LOCK_RESERVED:
if m.reserved {
return sqlite3.BUSY
}
m.reserved = true
case vfs.LOCK_EXCLUSIVE:
if m.lock < vfs.LOCK_PENDING {
m.lock = vfs.LOCK_PENDING
m.pending = true
}
for before := time.Now(); m.shared > 1; {
if time.Since(before) > spinWait {
return sqlite3.BUSY
}
m.lockMtx.Unlock()
runtime.Gosched()
m.lockMtx.Lock()
}
}
m.lock = lock
return nil
}
func (m *memFile) Unlock(lock vfs.LockLevel) error {
if m.lock <= lock {
return nil
}
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
if m.lock >= vfs.LOCK_RESERVED {
m.reserved = false
}
if m.lock >= vfs.LOCK_PENDING {
m.pending = false
}
if lock < vfs.LOCK_SHARED {
m.shared--
}
m.lock = lock
return nil
}
func (m *memFile) CheckReservedLock() (bool, error) {
// notest // OPEN_MEMORY
if m.lock >= vfs.LOCK_RESERVED {
return true, nil
}
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
return m.reserved, nil
}
func (m *memFile) SectorSize() int {
// notest // IOCAP_POWERSAFE_OVERWRITE
return sectorSize
}
func (m *memFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
return vfs.IOCAP_ATOMIC |
vfs.IOCAP_SEQUENTIAL |
vfs.IOCAP_SAFE_APPEND |
vfs.IOCAP_POWERSAFE_OVERWRITE
}
func (m *memFile) SizeHint(size int64) error {
m.dataMtx.Lock()
defer m.dataMtx.Unlock()
if size > m.size {
return m.truncate(size)
}
return nil
}
func (m *memFile) LockState() vfs.LockLevel {
return m.lock
}
// divRoundUp returns a divided by b, rounded up.
func divRoundUp(a, b int64) int64 {
return (a + b - 1) / b
}
// modRoundUp is like a % b, except that multiples of b map to b;
// it yields the number of bytes used in the final sector.
func modRoundUp(a, b int64) int64 {
return b - (b-a%b)%b
}
//go:build amd64 || arm64 || riscv64
package vfs
import (
"os"
"golang.org/x/sys/unix"
)
const (
// https://godbolt.org/z/1PcK5vea3
_F2FS_IOC_START_ATOMIC_WRITE = 62721
_F2FS_IOC_COMMIT_ATOMIC_WRITE = 62722
_F2FS_IOC_ABORT_ATOMIC_WRITE = 62725
_F2FS_IOC_GET_FEATURES = 2147808524 // -2147158772
_F2FS_FEATURE_ATOMIC_WRITE = 4
)
// notest
func osBatchAtomic(file *os.File) bool {
flags, err := unix.IoctlGetInt(int(file.Fd()), _F2FS_IOC_GET_FEATURES)
return err == nil && flags&_F2FS_FEATURE_ATOMIC_WRITE != 0
}
func (f *vfsFile) BeginAtomicWrite() error {
return unix.IoctlSetInt(int(f.Fd()), _F2FS_IOC_START_ATOMIC_WRITE, 0)
}
func (f *vfsFile) CommitAtomicWrite() error {
return unix.IoctlSetInt(int(f.Fd()), _F2FS_IOC_COMMIT_ATOMIC_WRITE, 0)
}
func (f *vfsFile) RollbackAtomicWrite() error {
return unix.IoctlSetInt(int(f.Fd()), _F2FS_IOC_ABORT_ATOMIC_WRITE, 0)
}
//go:build !sqlite3_flock
package vfs
import (
"io"
"os"
"time"
"golang.org/x/sys/unix"
)
func osSync(file *os.File, _ /*fullsync*/, _ /*dataonly*/ bool) error {
// SQLite trusts Linux's fdatasync for all fsync's.
for {
err := unix.Fdatasync(int(file.Fd()))
if err != unix.EINTR {
return err
}
}
}
func osAllocate(file *os.File, size int64) error {
if size == 0 {
return nil
}
for {
err := unix.Fallocate(int(file.Fd()), 0, 0, size)
if err == unix.EOPNOTSUPP {
break
}
if err != unix.EINTR {
return err
}
}
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
return file.Truncate(size)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK)
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := unix.Flock_t{
Type: typ,
Start: start,
Len: len,
}
var err error
switch {
default:
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
case timeout < 0:
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLKW, &lock)
}
return osLockErrorCode(err, def)
}
func osUnlock(file *os.File, start, len int64) _ErrorCode {
lock := unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
}
for {
err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if err == nil {
return _OK
}
if err != unix.EINTR {
return _IOERR_UNLOCK
}
}
}
//go:build (linux || darwin) && !(sqlite3_flock || sqlite3_dotlk)
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osGetSharedLock(file *os.File) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if lock, _ := osTestLock(file, _PENDING_BYTE, 1); lock == unix.F_WRLCK {
return _BUSY
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
func osGetReservedLock(file *os.File) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, 0)
}
func osGetExclusiveLock(file *os.File, state *LockLevel) _ErrorCode {
if *state == LOCK_RESERVED {
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if rc := osWriteLock(file, _PENDING_BYTE, 1, -1); rc != _OK {
return rc
}
*state = LOCK_PENDING
}
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, time.Millisecond)
}
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
if state >= LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// notest // this should never happen
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
lock, rc := osTestLock(file, _RESERVED_BYTE, 1)
return lock == unix.F_WRLCK, rc
}
//go:build unix
package vfs
import (
"os"
"syscall"
"golang.org/x/sys/unix"
)
const (
isUnix = true
_O_NOFOLLOW = unix.O_NOFOLLOW
)
func osAccess(path string, flags AccessFlag) error {
var access uint32 // unix.F_OK
switch flags {
case ACCESS_READWRITE:
access = unix.R_OK | unix.W_OK
case ACCESS_READ:
access = unix.R_OK
}
return unix.Access(path, access)
}
func osReadAt(file *os.File, p []byte, off int64) (int, error) {
n, err := file.ReadAt(p, off)
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.ERANGE,
unix.EIO,
unix.ENXIO:
return n, _IOERR_CORRUPTFS
}
}
return n, err
}
func osWriteAt(file *os.File, p []byte, off int64) (int, error) {
n, err := file.WriteAt(p, off)
if errno, ok := err.(unix.Errno); ok && errno == unix.ENOSPC {
return n, _FULL
}
return n, err
}
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
if sys, ok := fi.Sys().(*syscall.Stat_t); ok {
file.Chown(int(sys.Uid), int(sys.Gid))
}
return nil
}
func osTestLock(file *os.File, start, len int64) (int16, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_WRLCK,
Start: start,
Len: len,
}
for {
err := unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock)
if err == nil {
return lock.Type, _OK
}
if err != unix.EINTR {
return 0, _IOERR_CHECKRESERVEDLOCK
}
}
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _BUSY
case unix.EPERM:
return _PERM
}
// notest // usually EWOULDBLOCK == EAGAIN
if errno == unix.EWOULDBLOCK && unix.EWOULDBLOCK != unix.EAGAIN {
return _BUSY
}
}
return def
}
// Package readervfs implements an SQLite VFS for immutable databases.
//
// The "reader" [vfs.VFS] permits accessing any [io.ReaderAt]
// as an immutable SQLite database.
//
// Importing package readervfs registers the VFS:
//
// import _ "github.com/ncruces/go-sqlite3/vfs/readervfs"
package readervfs
import (
"sync"
"github.com/ncruces/go-sqlite3/util/ioutil"
"github.com/ncruces/go-sqlite3/vfs"
)
func init() {
vfs.Register("reader", readerVFS{})
}
var (
readerMtx sync.RWMutex
// +checklocks:readerMtx
readerDBs = map[string]ioutil.SizeReaderAt{}
)
// Create creates an immutable database from reader.
// The caller should ensure that data from reader does not mutate,
// otherwise SQLite might return incorrect query results and/or [sqlite3.CORRUPT] errors.
func Create(name string, reader ioutil.SizeReaderAt) {
readerMtx.Lock()
defer readerMtx.Unlock()
readerDBs[name] = reader
}
// Delete deletes an immutable database.
func Delete(name string) {
readerMtx.Lock()
defer readerMtx.Unlock()
delete(readerDBs, name)
}
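// A hedged usage sketch (the database/sql driver and its "sqlite3" name are
// assumptions, and *bytes.Reader is assumed to satisfy ioutil.SizeReaderAt
// since it provides both ReadAt and Size): serve a read-only database image
// through the "reader" VFS:
//
// readervfs.Create("demo.db", bytes.NewReader(dbImage))
// db, err := sql.Open("sqlite3", "file:demo.db?vfs=reader")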
package readervfs
import (
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/util/ioutil"
"github.com/ncruces/go-sqlite3/vfs"
)
type readerVFS struct{}
func (readerVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
if flags&vfs.OPEN_MAIN_DB == 0 {
// notest
return nil, flags, sqlite3.CANTOPEN
}
readerMtx.RLock()
defer readerMtx.RUnlock()
if ra, ok := readerDBs[name]; ok {
return readerFile{ra}, flags | vfs.OPEN_READONLY, nil
}
return nil, flags, sqlite3.CANTOPEN
}
func (readerVFS) Delete(name string, dirSync bool) error {
// notest
return sqlite3.IOERR_DELETE
}
func (readerVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
// notest
return false, nil
}
func (readerVFS) FullPathname(name string) (string, error) {
return name, nil
}
type readerFile struct{ ioutil.SizeReaderAt }
func (readerFile) Close() error {
return nil
}
func (readerFile) WriteAt(b []byte, off int64) (n int, err error) {
// notest
return 0, sqlite3.READONLY
}
func (readerFile) Truncate(size int64) error {
// notest
return sqlite3.READONLY
}
func (readerFile) Sync(flag vfs.SyncFlag) error {
// notest
return nil
}
func (readerFile) Lock(lock vfs.LockLevel) error {
// notest
return nil
}
func (readerFile) Unlock(lock vfs.LockLevel) error {
// notest
return nil
}
func (readerFile) CheckReservedLock() (bool, error) {
// notest
return false, nil
}
func (readerFile) SectorSize() int {
// notest
return 0
}
func (readerFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
return vfs.IOCAP_IMMUTABLE | vfs.IOCAP_SUBPAGE_READ
}
package vfs
import "sync"
var (
// +checklocks:vfsRegistryMtx
vfsRegistry map[string]VFS
vfsRegistryMtx sync.RWMutex
)
// Find returns a VFS given its name.
// If there is no match, nil is returned.
// If name is empty, the default VFS is returned.
//
// https://sqlite.org/c3ref/vfs_find.html
func Find(name string) VFS {
if name == "" || name == "os" {
return vfsOS{}
}
vfsRegistryMtx.RLock()
defer vfsRegistryMtx.RUnlock()
return vfsRegistry[name]
}
// Register registers a VFS.
// Empty and "os" are reserved names.
//
// https://sqlite.org/c3ref/vfs_find.html
func Register(name string, vfs VFS) {
if name == "" || name == "os" {
return
}
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
if vfsRegistry == nil {
vfsRegistry = map[string]VFS{}
}
vfsRegistry[name] = vfs
}
// Unregister unregisters a VFS.
//
// https://sqlite.org/c3ref/vfs_find.html
func Unregister(name string) {
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
delete(vfsRegistry, name)
}
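// A hedged example of the registry, assuming myVFS is a type that
// implements the VFS interface:
//
// vfs.Register("myvfs", myVFS{})
// v := vfs.Find("myvfs")  // returns myVFS{}
// vfs.Unregister("myvfs")
// v = vfs.Find("myvfs")   // now returns nil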
//go:build ((linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos) && (386 || arm || amd64 || arm64 || riscv64 || ppc64le)) || sqlite3_flock || sqlite3_dotlk
package vfs
// SupportsSharedMemory is false on platforms that do not support shared memory.
// To use [WAL without shared-memory], you need to set [EXCLUSIVE locking mode].
//
// [WAL without shared-memory]: https://sqlite.org/wal.html#noshm
// [EXCLUSIVE locking mode]: https://sqlite.org/pragma.html#pragma_locking_mode
const SupportsSharedMemory = true
func (f *vfsFile) SharedMemory() SharedMemory { return f.shm }
// NewSharedMemory returns a shared-memory WAL-index
// backed by a file with the given path.
// It will return nil if shared-memory is not supported,
// or not appropriate for the given flags.
// Only [OPEN_MAIN_DB] databases may need a WAL-index.
// You must ensure all concurrent accesses to a database
// use shared-memory instances created with the same path.
func NewSharedMemory(path string, flags OpenFlag) SharedMemory {
if flags&OPEN_MAIN_DB == 0 || flags&(OPEN_DELETEONCLOSE|OPEN_MEMORY) != 0 {
return nil
}
return &vfsShm{path: path}
}
//go:build (linux || darwin) && (386 || arm || amd64 || arm64 || riscv64 || ppc64le) && !(sqlite3_flock || sqlite3_dotlk)
package vfs
import (
"context"
"io"
"os"
"sync"
"time"
"github.com/tetratelabs/wazero/api"
"golang.org/x/sys/unix"
"github.com/ncruces/go-sqlite3/internal/util"
)
type vfsShm struct {
*os.File
path string
regions []*util.MappedRegion
readOnly bool
fileLock bool
blocking bool
sync.Mutex
}
var _ blockingSharedMemory = &vfsShm{}
func (s *vfsShm) shmOpen() _ErrorCode {
if s.File == nil {
f, err := os.OpenFile(s.path,
os.O_RDWR|os.O_CREATE|_O_NOFOLLOW, 0666)
if err != nil {
f, err = os.OpenFile(s.path,
os.O_RDONLY|os.O_CREATE|_O_NOFOLLOW, 0666)
s.readOnly = true
}
if err != nil {
return _CANTOPEN
}
s.File = f
}
if s.fileLock {
return _OK
}
// Dead man's switch.
if lock, rc := osTestLock(s.File, _SHM_DMS, 1); rc != _OK {
return _IOERR_LOCK
} else if lock == unix.F_WRLCK {
return _BUSY
} else if lock == unix.F_UNLCK {
if s.readOnly {
return _READONLY_CANTINIT
}
// Do not use a blocking lock here.
// If the lock cannot be obtained immediately,
// it means some other connection is truncating the file.
// And after it has done so, it will not release its lock,
// but only downgrade it to a shared lock.
// So no point in blocking here.
// The call below to obtain the shared DMS lock may use a blocking lock.
if rc := osWriteLock(s.File, _SHM_DMS, 1, 0); rc != _OK {
return rc
}
if err := s.Truncate(0); err != nil {
return _IOERR_SHMOPEN
}
}
rc := osReadLock(s.File, _SHM_DMS, 1, time.Millisecond)
s.fileLock = rc == _OK
return rc
}
func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (ptr_t, _ErrorCode) {
// Ensure size is a multiple of the OS page size.
if int(size)&(unix.Getpagesize()-1) != 0 {
return 0, _IOERR_SHMMAP
}
if rc := s.shmOpen(); rc != _OK {
return 0, rc
}
// Check if file is big enough.
o, err := s.Seek(0, io.SeekEnd)
if err != nil {
return 0, _IOERR_SHMSIZE
}
if n := (int64(id) + 1) * int64(size); n > o {
if !extend {
return 0, _OK
}
if s.readOnly || osAllocate(s.File, n) != nil {
return 0, _IOERR_SHMSIZE
}
}
r, err := util.MapRegion(ctx, mod, s.File, int64(id)*int64(size), size, s.readOnly)
if err != nil {
return 0, _IOERR_SHMMAP
}
s.regions = append(s.regions, r)
if s.readOnly {
return r.Ptr, _READONLY
}
return r.Ptr, _OK
}
func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) _ErrorCode {
// Argument check.
switch {
case n <= 0:
panic(util.AssertErr())
case offset < 0 || offset+n > _SHM_NLOCK:
panic(util.AssertErr())
case n != 1 && flags&_SHM_EXCLUSIVE == 0:
panic(util.AssertErr())
}
switch flags {
case
_SHM_LOCK | _SHM_SHARED,
_SHM_LOCK | _SHM_EXCLUSIVE,
_SHM_UNLOCK | _SHM_SHARED,
_SHM_UNLOCK | _SHM_EXCLUSIVE:
//
default:
panic(util.AssertErr())
}
var timeout time.Duration
if s.blocking {
timeout = time.Millisecond
}
switch {
case flags&_SHM_UNLOCK != 0:
return osUnlock(s.File, _SHM_BASE+int64(offset), int64(n))
case flags&_SHM_SHARED != 0:
return osReadLock(s.File, _SHM_BASE+int64(offset), int64(n), timeout)
case flags&_SHM_EXCLUSIVE != 0:
return osWriteLock(s.File, _SHM_BASE+int64(offset), int64(n), timeout)
default:
panic(util.AssertErr())
}
}
func (s *vfsShm) shmUnmap(delete bool) {
if s.File == nil {
return
}
// Unmap regions.
for _, r := range s.regions {
r.Unmap()
}
s.regions = nil
// Close the file.
if delete {
os.Remove(s.path)
}
s.Close()
s.File = nil
}
func (s *vfsShm) shmBarrier() {
s.Lock()
//lint:ignore SA2001 memory barrier.
s.Unlock()
}
func (s *vfsShm) shmEnableBlocking(block bool) {
s.blocking = block
}
package vfs
import (
"context"
"crypto/rand"
"io"
"reflect"
"strings"
"time"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/sql3util"
"github.com/ncruces/julianday"
)
// ExportHostFunctions is an internal API users need not call directly.
//
// ExportHostFunctions registers the required VFS host functions
// with the provided env module.
func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_vfs_find", vfsFind)
util.ExportFuncIIJ(env, "go_localtime", vfsLocaltime)
util.ExportFuncIIII(env, "go_randomness", vfsRandomness)
util.ExportFuncIII(env, "go_sleep", vfsSleep)
util.ExportFuncIII(env, "go_current_time_64", vfsCurrentTime64)
util.ExportFuncIIIII(env, "go_full_pathname", vfsFullPathname)
util.ExportFuncIIII(env, "go_delete", vfsDelete)
util.ExportFuncIIIII(env, "go_access", vfsAccess)
util.ExportFuncIIIIIII(env, "go_open", vfsOpen)
util.ExportFuncII(env, "go_close", vfsClose)
util.ExportFuncIIIIJ(env, "go_read", vfsRead)
util.ExportFuncIIIIJ(env, "go_write", vfsWrite)
util.ExportFuncIIJ(env, "go_truncate", vfsTruncate)
util.ExportFuncIII(env, "go_sync", vfsSync)
util.ExportFuncIII(env, "go_file_size", vfsFileSize)
util.ExportFuncIIII(env, "go_file_control", vfsFileControl)
util.ExportFuncII(env, "go_sector_size", vfsSectorSize)
util.ExportFuncII(env, "go_device_characteristics", vfsDeviceCharacteristics)
util.ExportFuncIII(env, "go_lock", vfsLock)
util.ExportFuncIII(env, "go_unlock", vfsUnlock)
util.ExportFuncIII(env, "go_check_reserved_lock", vfsCheckReservedLock)
util.ExportFuncIIIIII(env, "go_shm_map", vfsShmMap)
util.ExportFuncIIIII(env, "go_shm_lock", vfsShmLock)
util.ExportFuncIII(env, "go_shm_unmap", vfsShmUnmap)
util.ExportFuncVI(env, "go_shm_barrier", vfsShmBarrier)
return env
}
func vfsFind(ctx context.Context, mod api.Module, zVfsName ptr_t) uint32 {
name := util.ReadString(mod, zVfsName, _MAX_NAME)
if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) {
return 1
}
return 0
}
func vfsLocaltime(ctx context.Context, mod api.Module, pTm ptr_t, t int64) _ErrorCode {
const size = 32 / 8
tm := time.Unix(t, 0)
// https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html
util.Write32(mod, pTm+0*size, int32(tm.Second()))
util.Write32(mod, pTm+1*size, int32(tm.Minute()))
util.Write32(mod, pTm+2*size, int32(tm.Hour()))
util.Write32(mod, pTm+3*size, int32(tm.Day()))
util.Write32(mod, pTm+4*size, int32(tm.Month()-time.January))
util.Write32(mod, pTm+5*size, int32(tm.Year()-1900))
util.Write32(mod, pTm+6*size, int32(tm.Weekday()-time.Sunday))
util.Write32(mod, pTm+7*size, int32(tm.YearDay()-1))
util.WriteBool(mod, pTm+8*size, tm.IsDST())
return _OK
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs ptr_t, nByte int32, zByte ptr_t) uint32 {
mem := util.View(mod, zByte, int64(nByte))
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
func vfsSleep(ctx context.Context, mod api.Module, pVfs ptr_t, nMicro int32) _ErrorCode {
time.Sleep(time.Duration(nMicro) * time.Microsecond)
return _OK
}
func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow ptr_t) _ErrorCode {
day, nsec := julianday.Date(time.Now())
msec := day*86_400_000 + nsec/1_000_000
util.Write64(mod, piNow, msec)
return _OK
}
func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative ptr_t, nFull int32, zFull ptr_t) _ErrorCode {
vfs := vfsGet(mod, pVfs)
path := util.ReadString(mod, zRelative, _MAX_PATHNAME)
path, err := vfs.FullPathname(path)
if len(path) >= int(nFull) {
return _CANTOPEN_FULLPATH
}
util.WriteString(mod, zFull, path)
return vfsErrorCode(err, _CANTOPEN_FULLPATH)
}
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, syncDir int32) _ErrorCode {
vfs := vfsGet(mod, pVfs)
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
err := vfs.Delete(path, syncDir != 0)
return vfsErrorCode(err, _IOERR_DELETE)
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, flags AccessFlag, pResOut ptr_t) _ErrorCode {
vfs := vfsGet(mod, pVfs)
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
ok, err := vfs.Access(path, flags)
util.WriteBool(mod, pResOut, ok)
return vfsErrorCode(err, _IOERR_ACCESS)
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile ptr_t, flags OpenFlag, pOutFlags, pOutVFS ptr_t) _ErrorCode {
vfs := vfsGet(mod, pVfs)
name := GetFilename(ctx, mod, zPath, flags)
var file File
var err error
if ffs, ok := vfs.(VFSFilename); ok {
file, flags, err = ffs.OpenFilename(name, flags)
} else {
file, flags, err = vfs.Open(name.String(), flags)
}
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
if file, ok := file.(FilePowersafeOverwrite); ok {
if b, ok := sql3util.ParseBool(name.URIParameter("psow")); ok {
file.SetPowersafeOverwrite(b)
}
}
if file, ok := file.(FileSharedMemory); ok && pOutVFS != 0 {
util.WriteBool(mod, pOutVFS, file.SharedMemory() != nil)
}
if pOutFlags != 0 {
util.Write32(mod, pOutFlags, flags)
}
file = cksmWrapFile(name, flags, file)
vfsFileRegister(ctx, mod, pFile, file)
return _OK
}
func vfsClose(ctx context.Context, mod api.Module, pFile ptr_t) _ErrorCode {
err := vfsFileClose(ctx, mod, pFile)
return vfsErrorCode(err, _IOERR_CLOSE)
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf ptr_t, iAmt int32, iOfst int64) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile).(File)
buf := util.View(mod, zBuf, int64(iAmt))
n, err := file.ReadAt(buf, iOfst)
if n == int(iAmt) {
return _OK
}
if err != io.EOF {
return vfsErrorCode(err, _IOERR_READ)
}
clear(buf[n:])
return _IOERR_SHORT_READ
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf ptr_t, iAmt int32, iOfst int64) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile).(File)
buf := util.View(mod, zBuf, int64(iAmt))
_, err := file.WriteAt(buf, iOfst)
return vfsErrorCode(err, _IOERR_WRITE)
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile ptr_t, nByte int64) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile).(File)
err := file.Truncate(nByte)
return vfsErrorCode(err, _IOERR_TRUNCATE)
}
func vfsSync(ctx context.Context, mod api.Module, pFile ptr_t, flags SyncFlag) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile).(File)
err := file.Sync(flags)
return vfsErrorCode(err, _IOERR_FSYNC)
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize ptr_t) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile).(File)
size, err := file.Size()
util.Write64(mod, pSize, size)
return vfsErrorCode(err, _IOERR_SEEK)
}
func vfsLock(ctx context.Context, mod api.Module, pFile ptr_t, eLock LockLevel) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile).(File)
err := file.Lock(eLock)
return vfsErrorCode(err, _IOERR_LOCK)
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile ptr_t, eLock LockLevel) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile).(File)
err := file.Unlock(eLock)
return vfsErrorCode(err, _IOERR_UNLOCK)
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut ptr_t) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile).(File)
locked, err := file.CheckReservedLock()
util.WriteBool(mod, pResOut, locked)
return vfsErrorCode(err, _IOERR_CHECKRESERVEDLOCK)
}
func vfsFileControl(ctx context.Context, mod api.Module, pFile ptr_t, op _FcntlOpcode, pArg ptr_t) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile).(File)
if file, ok := file.(fileControl); ok {
return file.fileControl(ctx, mod, op, pArg)
}
return vfsFileControlImpl(ctx, mod, file, op, pArg)
}
func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _FcntlOpcode, pArg ptr_t) _ErrorCode {
switch op {
case _FCNTL_LOCKSTATE:
if file, ok := file.(FileLockState); ok {
if lk := file.LockState(); lk <= LOCK_EXCLUSIVE {
util.Write32(mod, pArg, lk)
return _OK
}
}
case _FCNTL_PERSIST_WAL:
if file, ok := file.(FilePersistWAL); ok {
if i := util.Read32[int32](mod, pArg); i < 0 {
util.WriteBool(mod, pArg, file.PersistWAL())
} else {
file.SetPersistWAL(i != 0)
}
return _OK
}
case _FCNTL_POWERSAFE_OVERWRITE:
if file, ok := file.(FilePowersafeOverwrite); ok {
if i := util.Read32[int32](mod, pArg); i < 0 {
util.WriteBool(mod, pArg, file.PowersafeOverwrite())
} else {
file.SetPowersafeOverwrite(i != 0)
}
return _OK
}
case _FCNTL_CHUNK_SIZE:
if file, ok := file.(FileChunkSize); ok {
size := util.Read32[int32](mod, pArg)
file.ChunkSize(int(size))
return _OK
}
case _FCNTL_SIZE_HINT:
if file, ok := file.(FileSizeHint); ok {
size := util.Read64[int64](mod, pArg)
err := file.SizeHint(size)
return vfsErrorCode(err, _IOERR_TRUNCATE)
}
case _FCNTL_HAS_MOVED:
if file, ok := file.(FileHasMoved); ok {
moved, err := file.HasMoved()
util.WriteBool(mod, pArg, moved)
return vfsErrorCode(err, _IOERR_FSTAT)
}
case _FCNTL_OVERWRITE:
if file, ok := file.(FileOverwrite); ok {
err := file.Overwrite()
return vfsErrorCode(err, _IOERR)
}
case _FCNTL_SYNC:
if file, ok := file.(FileSync); ok {
var name string
if pArg != 0 {
name = util.ReadString(mod, pArg, _MAX_PATHNAME)
}
err := file.SyncSuper(name)
return vfsErrorCode(err, _IOERR)
}
case _FCNTL_COMMIT_PHASETWO:
if file, ok := file.(FileCommitPhaseTwo); ok {
err := file.CommitPhaseTwo()
return vfsErrorCode(err, _IOERR)
}
case _FCNTL_BEGIN_ATOMIC_WRITE:
if file, ok := file.(FileBatchAtomicWrite); ok {
err := file.BeginAtomicWrite()
return vfsErrorCode(err, _IOERR_BEGIN_ATOMIC)
}
case _FCNTL_COMMIT_ATOMIC_WRITE:
if file, ok := file.(FileBatchAtomicWrite); ok {
err := file.CommitAtomicWrite()
return vfsErrorCode(err, _IOERR_COMMIT_ATOMIC)
}
case _FCNTL_ROLLBACK_ATOMIC_WRITE:
if file, ok := file.(FileBatchAtomicWrite); ok {
err := file.RollbackAtomicWrite()
return vfsErrorCode(err, _IOERR_ROLLBACK_ATOMIC)
}
case _FCNTL_CKPT_START:
if file, ok := file.(FileCheckpoint); ok {
file.CheckpointStart()
return _OK
}
case _FCNTL_CKPT_DONE:
if file, ok := file.(FileCheckpoint); ok {
file.CheckpointDone()
return _OK
}
case _FCNTL_PRAGMA:
if file, ok := file.(FilePragma); ok {
ptr := util.Read32[ptr_t](mod, pArg+1*ptrlen)
name := util.ReadString(mod, ptr, _MAX_SQL_LENGTH)
var value string
if ptr := util.Read32[ptr_t](mod, pArg+2*ptrlen); ptr != 0 {
value = util.ReadString(mod, ptr, _MAX_SQL_LENGTH)
}
out, err := file.Pragma(strings.ToLower(name), value)
ret := vfsErrorCode(err, _ERROR)
if ret == _ERROR {
out = err.Error()
}
if out != "" {
fn := mod.ExportedFunction("sqlite3_malloc64")
stack := [...]stk_t{stk_t(len(out) + 1)}
if err := fn.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
util.Write32(mod, pArg, ptr_t(stack[0]))
util.WriteString(mod, ptr_t(stack[0]), out)
}
return ret
}
case _FCNTL_BUSYHANDLER:
if file, ok := file.(FileBusyHandler); ok {
arg := util.Read64[stk_t](mod, pArg)
fn := mod.ExportedFunction("sqlite3_invoke_busy_handler_go")
file.BusyHandler(func() bool {
stack := [...]stk_t{arg}
if err := fn.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
return uint32(stack[0]) != 0
})
return _OK
}
case _FCNTL_LOCK_TIMEOUT:
if file, ok := file.(FileSharedMemory); ok {
if shm, ok := file.SharedMemory().(blockingSharedMemory); ok {
shm.shmEnableBlocking(util.ReadBool(mod, pArg))
return _OK
}
}
case _FCNTL_PDB:
if file, ok := file.(filePDB); ok {
file.SetDB(ctx.Value(util.ConnKey{}))
return _OK
}
}
return _NOTFOUND
}
func vfsSectorSize(ctx context.Context, mod api.Module, pFile ptr_t) uint32 {
file := vfsFileGet(ctx, mod, pFile).(File)
return uint32(file.SectorSize())
}
func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile ptr_t) DeviceCharacteristic {
file := vfsFileGet(ctx, mod, pFile).(File)
return file.DeviceCharacteristics()
}
func vfsShmBarrier(ctx context.Context, mod api.Module, pFile ptr_t) {
shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
shm.shmBarrier()
}
func vfsShmMap(ctx context.Context, mod api.Module, pFile ptr_t, iRegion, szRegion, bExtend int32, pp ptr_t) _ErrorCode {
shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
p, rc := shm.shmMap(ctx, mod, iRegion, szRegion, bExtend != 0)
util.Write32(mod, pp, p)
return rc
}
func vfsShmLock(ctx context.Context, mod api.Module, pFile ptr_t, offset, n int32, flags _ShmFlag) _ErrorCode {
shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
return shm.shmLock(offset, n, flags)
}
func vfsShmUnmap(ctx context.Context, mod api.Module, pFile ptr_t, bDelete int32) _ErrorCode {
shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
shm.shmUnmap(bDelete != 0)
return _OK
}
func vfsGet(mod api.Module, pVfs ptr_t) VFS {
var name string
if pVfs != 0 {
const zNameOffset = 16
ptr := util.Read32[ptr_t](mod, pVfs+zNameOffset)
name = util.ReadString(mod, ptr, _MAX_NAME)
}
if vfs := Find(name); vfs != nil {
return vfs
}
panic(util.NoVFSErr + util.ErrorString(name))
}
func vfsFileRegister(ctx context.Context, mod api.Module, pFile ptr_t, file File) {
const fileHandleOffset = 4
id := util.AddHandle(ctx, file)
util.Write32(mod, pFile+fileHandleOffset, id)
}
func vfsFileGet(ctx context.Context, mod api.Module, pFile ptr_t) any {
const fileHandleOffset = 4
id := util.Read32[ptr_t](mod, pFile+fileHandleOffset)
return util.GetHandle(ctx, id)
}
func vfsFileClose(ctx context.Context, mod api.Module, pFile ptr_t) error {
const fileHandleOffset = 4
id := util.Read32[ptr_t](mod, pFile+fileHandleOffset)
return util.DelHandle(ctx, id)
}
func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
switch v := reflect.ValueOf(err); v.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32:
return _ErrorCode(v.Uint())
}
return def
}
package xts
import (
"crypto/aes"
"crypto/rand"
"crypto/sha512"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/xts"
)
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3/vfs/xts.pepper=xts"
var pepper = "github.com/ncruces/go-sqlite3/vfs/xts"
type aesCreator struct{}
func (aesCreator) XTS(key []byte) *xts.Cipher {
c, err := xts.NewCipher(aes.NewCipher, key)
if err != nil {
return nil
}
return c
}
func (aesCreator) KDF(text string) []byte {
if text == "" {
key := make([]byte, 32)
n, _ := rand.Read(key)
return key[:n]
}
return pbkdf2.Key([]byte(text), []byte(pepper), 10_000, 32, sha512.New)
}
// Package xts wraps an SQLite VFS to offer encryption at rest.
//
// The "xts" [vfs.VFS] wraps the default VFS using the
// AES-XTS tweakable, length-preserving encryption.
//
// Importing package xts registers that VFS:
//
// import _ "github.com/ncruces/go-sqlite3/vfs/xts"
//
// To open an encrypted database you need to provide key material.
//
// The simplest way to do that is to specify the key through a [URI] parameter:
//
// - key: key material in binary (32, 48 or 64 bytes)
// - hexkey: key material in hex (64, 96 or 128 hex digits)
// - textkey: key material in text (any length)
//
// However, this makes your key easily accessible to other parts of
// your application (e.g. through [vfs.Filename.URIParameters]).
//
// To avoid this, invoke any of the following PRAGMAs
// immediately after opening a connection:
//
// PRAGMA key='D41d8cD98f00b204e9800998eCf8427e';
// PRAGMA hexkey='e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855';
// PRAGMA textkey='your-secret-key';
//
// For an ATTACH-ed database, you must specify the schema name:
//
// ATTACH DATABASE 'demo.db' AS demo;
// PRAGMA demo.textkey='your-secret-key';
//
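// As a hedged usage sketch (assuming the database/sql driver registered
// by github.com/ncruces/go-sqlite3/driver), the key can be supplied
// directly in the URI when opening the database:
//
// db, err := sql.Open("sqlite3", "file:demo.db?vfs=xts&textkey=your-secret-key")
//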
// [URI]: https://sqlite.org/uri.html
package xts
import (
"golang.org/x/crypto/xts"
"github.com/ncruces/go-sqlite3/vfs"
)
func init() {
vfs.Register("xts", Wrap(vfs.Find(""), nil))
}
// Wrap wraps a base VFS to create an encrypting VFS,
// possibly using a custom XTS cipher construction.
//
// To use the default AES-XTS construction, set cipher to nil.
//
// The default construction uses AES-128, AES-192, or AES-256
// if the key/hexkey is 32, 48, or 64 bytes, respectively.
// If a textkey is provided, the default KDF is PBKDF2-HMAC-SHA512
// with 10,000 iterations, always producing a 32 byte key.
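//
// A hedged sketch of wiring a custom construction into its own VFS name
// (myCreator is a hypothetical type satisfying [XTSCreator]):
//
// vfs.Register("xts-custom", Wrap(vfs.Find(""), myCreator{}))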
func Wrap(base vfs.VFS, cipher XTSCreator) vfs.VFS {
if cipher == nil {
cipher = aesCreator{}
}
return &xtsVFS{
VFS: base,
init: cipher,
}
}
// XTSCreator creates an [xts.Cipher]
// given key material.
type XTSCreator interface {
// KDF derives an XTS key from a secret.
// If no secret is given, a random key is generated.
KDF(secret string) (key []byte)
// XTS creates an XTS cipher given a key.
// If key is not appropriate, nil is returned.
XTS(key []byte) *xts.Cipher
}
package xts
import (
"encoding/hex"
"io"
"golang.org/x/crypto/xts"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/vfsutil"
"github.com/ncruces/go-sqlite3/vfs"
)
type xtsVFS struct {
vfs.VFS
init XTSCreator
}
func (x *xtsVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
// notest // OpenFilename is called instead
return nil, 0, sqlite3.CANTOPEN
}
func (x *xtsVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
file, flags, err = vfsutil.WrapOpenFilename(x.VFS, name, flags)
// Encrypt everything except super journals and memory files.
if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 {
return file, flags, err
}
var cipher *xts.Cipher
if f, ok := vfsutil.UnwrapFile[*xtsFile](name.DatabaseFile()); ok {
cipher = f.cipher
} else {
var key []byte
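// A nil name means a nameless temporary file.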
if params := name.URIParameters(); name == nil {
key = x.init.KDF("") // Temporary files get a random key.
} else if t, ok := params["key"]; ok {
key = []byte(t[0])
} else if t, ok := params["hexkey"]; ok {
key, _ = hex.DecodeString(t[0])
} else if t, ok := params["textkey"]; ok && len(t[0]) > 0 {
key = x.init.KDF(t[0])
} else if flags&vfs.OPEN_MAIN_DB != 0 {
// Main databases may have their key specified as a PRAGMA.
return &xtsFile{File: file, init: x.init}, flags, nil
}
cipher = x.init.XTS(key)
}
if cipher == nil {
file.Close()
return nil, flags, sqlite3.CANTOPEN
}
return &xtsFile{File: file, cipher: cipher, init: x.init}, flags, nil
}
// Larger sectors don't seem to significantly improve security,
// and don't affect performance.
// https://crossbowerbt.github.io/docs/crypto/pdf00086.pdf
// For flexibility, pick the minimum size of an SQLite page.
// https://sqlite.org/fileformat.html#pages
const sectorSize = 512
// Ensure sectorSize is a power of two.
var _ [0]struct{} = [sectorSize & (sectorSize - 1)]struct{}{}
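// roundDown and roundUp align offsets and lengths to sectorSize;
// e.g. with sectorSize 512, roundDown(1000) == 512 and roundUp(1000) == 1024.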
func roundDown(i int64) int64 {
return i &^ (sectorSize - 1)
}
func roundUp[T int | int64](i T) T {
return (i + (sectorSize - 1)) &^ (sectorSize - 1)
}
type xtsFile struct {
vfs.File
init XTSCreator
cipher *xts.Cipher
sector [sectorSize]byte
}
func (x *xtsFile) Pragma(name string, value string) (string, error) {
var key []byte
switch name {
case "key":
key = []byte(value)
case "hexkey":
key, _ = hex.DecodeString(value)
case "textkey":
if len(value) > 0 {
key = x.init.KDF(value)
}
default:
return vfsutil.WrapPragma(x.File, name, value)
}
if x.cipher = x.init.XTS(key); x.cipher != nil {
return "ok", nil
}
return "", sqlite3.CANTOPEN
}
func (x *xtsFile) ReadAt(p []byte, off int64) (n int, err error) {
if x.cipher == nil {
// Only OPEN_MAIN_DB can have a missing key.
if off == 0 && len(p) == 100 {
// SQLite is trying to read the header of a database file.
// Pretend the file is empty so the key may be specified as a PRAGMA.
return 0, io.EOF
}
return 0, sqlite3.CANTOPEN
}
min := roundDown(off)
max := roundUp(off + int64(len(p)))
// Read one block at a time.
for ; min < max; min += sectorSize {
m, err := x.File.ReadAt(x.sector[:], min)
if m != sectorSize {
return n, err
}
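// The block's index within the file is used as the XTS sector number (tweak).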
sectorNum := uint64(min / sectorSize)
x.cipher.Decrypt(x.sector[:], x.sector[:], sectorNum)
data := x.sector[:]
if off > min {
data = data[off-min:]
}
n += copy(p[n:], data)
}
if n != len(p) {
panic(util.AssertErr())
}
return n, nil
}
func (x *xtsFile) WriteAt(p []byte, off int64) (n int, err error) {
if x.cipher == nil {
return 0, sqlite3.READONLY
}
min := roundDown(off)
max := roundUp(off + int64(len(p)))
// Write one block at a time.
for ; min < max; min += sectorSize {
sectorNum := uint64(min / sectorSize)
data := x.sector[:]
if off > min || len(p[n:]) < sectorSize {
// Partial block write: read-update-write.
m, err := x.File.ReadAt(x.sector[:], min)
if m != sectorSize {
if err != io.EOF {
return n, err
}
// Writing past the EOF.
// We're either appending an entirely new block,
// or the final block was only partially written.
// A partially written block can't be decrypted,
// and is as good as corrupt.
// Either way, zero pad the file to the next block size.
clear(data)
} else {
x.cipher.Decrypt(data, data, sectorNum)
}
if off > min {
data = data[off-min:]
}
}
t := copy(data, p[n:])
x.cipher.Encrypt(x.sector[:], x.sector[:], sectorNum)
m, err := x.File.WriteAt(x.sector[:], min)
if m != sectorSize {
return n, err
}
n += t
}
if n != len(p) {
panic(util.AssertErr())
}
return n, nil
}
func (x *xtsFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
var _ [0]struct{} = [sectorSize - 512]struct{}{} // Ensure sectorSize is 512.
return x.File.DeviceCharacteristics() & (0 |
// These flags are safe:
vfs.IOCAP_ATOMIC512 |
vfs.IOCAP_IMMUTABLE |
vfs.IOCAP_SEQUENTIAL |
vfs.IOCAP_SUBPAGE_READ |
vfs.IOCAP_BATCH_ATOMIC |
vfs.IOCAP_UNDELETABLE_WHEN_OPEN)
}
func (x *xtsFile) SectorSize() int {
return util.LCM(x.File.SectorSize(), sectorSize)
}
func (x *xtsFile) Truncate(size int64) error {
return x.File.Truncate(roundUp(size))
}
func (x *xtsFile) ChunkSize(size int) {
vfsutil.WrapChunkSize(x.File, roundUp(size))
}
func (x *xtsFile) SizeHint(size int64) error {
return vfsutil.WrapSizeHint(x.File, roundUp(size))
}
func (x *xtsFile) Unwrap() vfs.File {
return x.File
}
func (x *xtsFile) SharedMemory() vfs.SharedMemory {
return vfsutil.WrapSharedMemory(x.File)
}
// Wrap optional methods.
func (x *xtsFile) LockState() vfs.LockLevel {
return vfsutil.WrapLockState(x.File) // notest
}
func (x *xtsFile) PersistWAL() bool {
return vfsutil.WrapPersistWAL(x.File) // notest
}
func (x *xtsFile) SetPersistWAL(keepWAL bool) {
vfsutil.WrapSetPersistWAL(x.File, keepWAL) // notest
}
func (x *xtsFile) HasMoved() (bool, error) {
return vfsutil.WrapHasMoved(x.File) // notest
}
func (x *xtsFile) Overwrite() error {
return vfsutil.WrapOverwrite(x.File) // notest
}
func (x *xtsFile) SyncSuper(super string) error {
return vfsutil.WrapSyncSuper(x.File, super) // notest
}
func (x *xtsFile) CommitPhaseTwo() error {
return vfsutil.WrapCommitPhaseTwo(x.File) // notest
}
func (x *xtsFile) BeginAtomicWrite() error {
return vfsutil.WrapBeginAtomicWrite(x.File) // notest
}
func (x *xtsFile) CommitAtomicWrite() error {
return vfsutil.WrapCommitAtomicWrite(x.File) // notest
}
func (x *xtsFile) RollbackAtomicWrite() error {
return vfsutil.WrapRollbackAtomicWrite(x.File) // notest
}
func (x *xtsFile) CheckpointStart() {
vfsutil.WrapCheckpointStart(x.File) // notest
}
func (x *xtsFile) CheckpointDone() {
vfsutil.WrapCheckpointDone(x.File) // notest
}
func (x *xtsFile) BusyHandler(handler func() bool) {
vfsutil.WrapBusyHandler(x.File, handler) // notest
}
package sqlite3
import (
"context"
"errors"
"reflect"
"github.com/tetratelabs/wazero/api"
"github.com/ncruces/go-sqlite3/internal/util"
)
// CreateModule registers a new virtual table module name.
// If create is nil, the virtual table is eponymous.
//
// https://sqlite.org/c3ref/create_module.html
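//
// A hedged, minimal sketch of registering an eponymous module
// (seriesTable is a hypothetical type implementing [VTab]):
//
// err := CreateModule(db, "series", nil,
//	func(db *Conn, module, schema, table string, arg ...string) (*seriesTable, error) {
//		err := db.DeclareVTab(`CREATE TABLE x(value)`)
//		return &seriesTable{}, err
//	})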
func CreateModule[T VTab](db *Conn, name string, create, connect VTabConstructor[T]) error {
var flags int
const (
VTAB_CREATOR = 0x001
VTAB_DESTROYER = 0x002
VTAB_UPDATER = 0x004
VTAB_RENAMER = 0x008
VTAB_OVERLOADER = 0x010
VTAB_CHECKER = 0x020
VTAB_TXN = 0x040
VTAB_SAVEPOINTER = 0x080
VTAB_SHADOWTABS = 0x100
)
if create != nil {
flags |= VTAB_CREATOR
}
vtab := reflect.TypeOf(connect).Out(0)
if implements[VTabDestroyer](vtab) {
flags |= VTAB_DESTROYER
}
if implements[VTabUpdater](vtab) {
flags |= VTAB_UPDATER
}
if implements[VTabRenamer](vtab) {
flags |= VTAB_RENAMER
}
if implements[VTabOverloader](vtab) {
flags |= VTAB_OVERLOADER
}
if implements[VTabChecker](vtab) {
flags |= VTAB_CHECKER
}
if implements[VTabTxn](vtab) {
flags |= VTAB_TXN
}
if implements[VTabSavepointer](vtab) {
flags |= VTAB_SAVEPOINTER
}
if implements[VTabShadowTabler](vtab) {
flags |= VTAB_SHADOWTABS
}
var modulePtr ptr_t
defer db.arena.mark()()
namePtr := db.arena.string(name)
if connect != nil {
modulePtr = util.AddHandle(db.ctx, module[T]{create, connect})
}
rc := res_t(db.call("sqlite3_create_module_go", stk_t(db.handle),
stk_t(namePtr), stk_t(flags), stk_t(modulePtr)))
return db.error(rc)
}
func implements[T any](typ reflect.Type) bool {
var ptr *T
return typ.Implements(reflect.TypeOf(ptr).Elem())
}
// DeclareVTab declares the schema of a virtual table.
//
// https://sqlite.org/c3ref/declare_vtab.html
func (c *Conn) DeclareVTab(sql string) error {
if c.interrupt.Err() != nil {
return INTERRUPT
}
defer c.arena.mark()()
textPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(textPtr)))
return c.error(rc)
}
// VTabConflictMode is a virtual table conflict resolution mode.
//
// https://sqlite.org/c3ref/c_fail.html
type VTabConflictMode uint8
const (
VTAB_ROLLBACK VTabConflictMode = 1
VTAB_IGNORE VTabConflictMode = 2
VTAB_FAIL VTabConflictMode = 3
VTAB_ABORT VTabConflictMode = 4
VTAB_REPLACE VTabConflictMode = 5
)
// VTabOnConflict determines the virtual table conflict policy.
//
// https://sqlite.org/c3ref/vtab_on_conflict.html
func (c *Conn) VTabOnConflict() VTabConflictMode {
return VTabConflictMode(c.call("sqlite3_vtab_on_conflict", stk_t(c.handle)))
}
// VTabConfigOption is a virtual table configuration option.
//
// https://sqlite.org/c3ref/c_vtab_constraint_support.html
type VTabConfigOption uint8
const (
VTAB_CONSTRAINT_SUPPORT VTabConfigOption = 1
VTAB_INNOCUOUS VTabConfigOption = 2
VTAB_DIRECTONLY VTabConfigOption = 3
VTAB_USES_ALL_SCHEMAS VTabConfigOption = 4
)
// VTabConfig configures various facets of the virtual table interface.
//
// https://sqlite.org/c3ref/vtab_config.html
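//
// A hedged example, typically called from a virtual table constructor:
//
// err := db.VTabConfig(VTAB_CONSTRAINT_SUPPORT, true)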
func (c *Conn) VTabConfig(op VTabConfigOption, args ...any) error {
var i int32
if op == VTAB_CONSTRAINT_SUPPORT && len(args) > 0 {
if b, ok := args[0].(bool); ok && b {
i = 1
}
}
rc := res_t(c.call("sqlite3_vtab_config_go", stk_t(c.handle), stk_t(op), stk_t(i)))
return c.error(rc)
}
// VTabConstructor is a virtual table constructor function.
type VTabConstructor[T VTab] func(db *Conn, module, schema, table string, arg ...string) (T, error)
type module[T VTab] [2]VTabConstructor[T]
type vtabConstructor int
const (
xCreate vtabConstructor = 0
xConnect vtabConstructor = 1
)
// A VTab describes a particular instance of the virtual table.
// A VTab may optionally implement [io.Closer] to free resources.
//
// https://sqlite.org/c3ref/vtab.html
type VTab interface {
// https://sqlite.org/vtab.html#xbestindex
BestIndex(*IndexInfo) error
// https://sqlite.org/vtab.html#xopen
Open() (VTabCursor, error)
}
// A VTabDestroyer allows a virtual table to drop persistent state.
type VTabDestroyer interface {
VTab
// https://sqlite.org/vtab.html#sqlite3_module.xDestroy
Destroy() error
}
// A VTabUpdater allows a virtual table to be updated.
// Implementations must not retain arg.
type VTabUpdater interface {
VTab
// https://sqlite.org/vtab.html#xupdate
Update(arg ...Value) (rowid int64, err error)
}
// A VTabRenamer allows a virtual table to be renamed.
type VTabRenamer interface {
VTab
// https://sqlite.org/vtab.html#xrename
Rename(new string) error
}
// A VTabOverloader allows a virtual table to overload SQL functions.
type VTabOverloader interface {
VTab
// https://sqlite.org/vtab.html#xfindfunction
FindFunction(arg int, name string) (ScalarFunction, IndexConstraintOp)
}
// A VTabShadowTabler allows a virtual table to protect the content
// of shadow tables from being corrupted by hostile SQL.
//
// Implementing this interface signals that a virtual table named
// "mumble" reserves all table names starting with "mumble_".
type VTabShadowTabler interface {
VTab
// https://sqlite.org/vtab.html#the_xshadowname_method
ShadowTables()
}
// A VTabChecker allows a virtual table to report errors
// to the PRAGMA integrity_check and PRAGMA quick_check commands.
//
// Integrity should return an error if it finds problems in the content of the virtual table,
// but should avoid returning a (wrapped) [Error], [ErrorCode] or [ExtendedErrorCode],
// as those indicate the Integrity method itself encountered problems
// while trying to evaluate the virtual table content.
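//
// A hedged example (myTable and its corrupt field are hypothetical):
//
// func (t *myTable) Integrity(schema, table string, flags int) error {
//	if t.corrupt {
//		// Reported by PRAGMA integrity_check as a problem with table content.
//		return errors.New("malformed content")
//	}
//	return nil
// }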
type VTabChecker interface {
VTab
// https://sqlite.org/vtab.html#xintegrity
Integrity(schema, table string, flags int) error
}
// A VTabTxn allows a virtual table to implement
// transactions with two-phase commit.
//
// Anything that is required as part of a commit that may fail
// should be performed in the Sync() callback.
// Current versions of SQLite ignore any errors
// returned by Commit() and Rollback().
type VTabTxn interface {
VTab
// https://sqlite.org/vtab.html#xBegin
Begin() error
// https://sqlite.org/vtab.html#xsync
Sync() error
// https://sqlite.org/vtab.html#xcommit
Commit() error
// https://sqlite.org/vtab.html#xrollback
Rollback() error
}
// A VTabSavepointer allows a virtual table to implement
// nested transactions.
//
// https://sqlite.org/vtab.html#xsavepoint
type VTabSavepointer interface {
VTabTxn
Savepoint(id int) error
Release(id int) error
RollbackTo(id int) error
}
// A VTabCursor describes cursors that point
// into the virtual table and are used
// to loop through the virtual table.
// A VTabCursor may optionally implement
// [io.Closer] to free resources.
// Implementations of Filter must not retain arg.
//
// https://sqlite.org/c3ref/vtab_cursor.html
type VTabCursor interface {
// https://sqlite.org/vtab.html#xfilter
Filter(idxNum int, idxStr string, arg ...Value) error
// https://sqlite.org/vtab.html#xnext
Next() error
// https://sqlite.org/vtab.html#xeof
EOF() bool
// https://sqlite.org/vtab.html#xcolumn
Column(ctx Context, n int) error
// https://sqlite.org/vtab.html#xrowid
RowID() (int64, error)
}
// An IndexInfo describes virtual table indexing information.
//
// https://sqlite.org/c3ref/index_info.html
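//
// A hedged sketch of a [VTab.BestIndex] implementation that consumes
// a usable equality constraint on column 0 (myTable is hypothetical):
//
// func (t *myTable) BestIndex(idx *IndexInfo) error {
//	for i, c := range idx.Constraint {
//		if c.Usable && c.Column == 0 && c.Op == INDEX_CONSTRAINT_EQ {
//			idx.ConstraintUsage[i] = IndexConstraintUsage{ArgvIndex: 1, Omit: true}
//			idx.EstimatedCost = 1
//			idx.EstimatedRows = 1
//			return nil
//		}
//	}
//	idx.EstimatedCost = 1e10
//	return nil
// }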
type IndexInfo struct {
// Inputs
Constraint []IndexConstraint
OrderBy []IndexOrderBy
ColumnsUsed uint64
// Outputs
ConstraintUsage []IndexConstraintUsage
IdxNum int
IdxStr string
IdxFlags IndexScanFlag
OrderByConsumed bool
EstimatedCost float64
EstimatedRows int64
// Internal
c *Conn
handle ptr_t
}
// An IndexConstraint describes virtual table indexing constraint information.
//
// https://sqlite.org/c3ref/index_info.html
type IndexConstraint struct {
Column int
Op IndexConstraintOp
Usable bool
}
// An IndexOrderBy describes virtual table indexing order by information.
//
// https://sqlite.org/c3ref/index_info.html
type IndexOrderBy struct {
Column int
Desc bool
}
// An IndexConstraintUsage describes how virtual table indexing constraints will be used.
//
// https://sqlite.org/c3ref/index_info.html
type IndexConstraintUsage struct {
ArgvIndex int
Omit bool
}
// RHSValue returns the value of the right-hand operand of a constraint
// if the right-hand operand is known.
//
// https://sqlite.org/c3ref/vtab_rhs_value.html
func (idx *IndexInfo) RHSValue(column int) (Value, error) {
defer idx.c.arena.mark()()
valPtr := idx.c.arena.new(ptrlen)
rc := res_t(idx.c.call("sqlite3_vtab_rhs_value", stk_t(idx.handle),
stk_t(column), stk_t(valPtr)))
if err := idx.c.error(rc); err != nil {
return Value{}, err
}
return Value{
c: idx.c,
handle: util.Read32[ptr_t](idx.c.mod, valPtr),
}, nil
}
// Collation returns the name of the collation for a virtual table constraint.
//
// https://sqlite.org/c3ref/vtab_collation.html
func (idx *IndexInfo) Collation(column int) string {
ptr := ptr_t(idx.c.call("sqlite3_vtab_collation", stk_t(idx.handle),
stk_t(column)))
return util.ReadString(idx.c.mod, ptr, _MAX_NAME)
}
// Distinct determines if a virtual table query is DISTINCT.
//
// https://sqlite.org/c3ref/vtab_distinct.html
func (idx *IndexInfo) Distinct() int {
i := int32(idx.c.call("sqlite3_vtab_distinct", stk_t(idx.handle)))
return int(i)
}
// In identifies and handles IN constraints.
//
// https://sqlite.org/c3ref/vtab_in.html
func (idx *IndexInfo) In(column, handle int) bool {
b := int32(idx.c.call("sqlite3_vtab_in", stk_t(idx.handle),
stk_t(column), stk_t(handle)))
return b != 0
}
func (idx *IndexInfo) load() {
// https://sqlite.org/c3ref/index_info.html
mod := idx.c.mod
ptr := idx.handle
nConstraint := util.Read32[int32](mod, ptr+0)
idx.Constraint = make([]IndexConstraint, nConstraint)
idx.ConstraintUsage = make([]IndexConstraintUsage, nConstraint)
idx.OrderBy = make([]IndexOrderBy, util.Read32[int32](mod, ptr+8))
constraintPtr := util.Read32[ptr_t](mod, ptr+4)
constraint := idx.Constraint
for i := range idx.Constraint {
constraint[i] = IndexConstraint{
Column: int(util.Read32[int32](mod, constraintPtr+0)),
Op: util.Read[IndexConstraintOp](mod, constraintPtr+4),
Usable: util.Read[byte](mod, constraintPtr+5) != 0,
}
constraintPtr += 12
}
orderByPtr := util.Read32[ptr_t](mod, ptr+12)
orderBy := idx.OrderBy
for i := range orderBy {
orderBy[i] = IndexOrderBy{
Column: int(util.Read32[int32](mod, orderByPtr+0)),
Desc: util.Read[byte](mod, orderByPtr+4) != 0,
}
orderByPtr += 8
}
idx.EstimatedCost = util.ReadFloat64(mod, ptr+40)
idx.EstimatedRows = util.Read64[int64](mod, ptr+48)
idx.ColumnsUsed = util.Read64[uint64](mod, ptr+64)
}
func (idx *IndexInfo) save() {
// https://sqlite.org/c3ref/index_info.html
mod := idx.c.mod
ptr := idx.handle
usagePtr := util.Read32[ptr_t](mod, ptr+16)
for _, usage := range idx.ConstraintUsage {
util.Write32(mod, usagePtr+0, int32(usage.ArgvIndex))
if usage.Omit {
util.Write(mod, usagePtr+4, int8(1))
}
usagePtr += 8
}
util.Write32(mod, ptr+20, int32(idx.IdxNum))
if idx.IdxStr != "" {
util.Write32(mod, ptr+24, idx.c.newString(idx.IdxStr))
util.WriteBool(mod, ptr+28, true) // needToFreeIdxStr
}
if idx.OrderByConsumed {
util.WriteBool(mod, ptr+32, true)
}
util.WriteFloat64(mod, ptr+40, idx.EstimatedCost)
util.Write64(mod, ptr+48, idx.EstimatedRows)
util.Write32(mod, ptr+56, idx.IdxFlags)
}
// IndexConstraintOp is a virtual table constraint operator code.
//
// https://sqlite.org/c3ref/c_index_constraint_eq.html
type IndexConstraintOp uint8
const (
INDEX_CONSTRAINT_EQ IndexConstraintOp = 2
INDEX_CONSTRAINT_GT IndexConstraintOp = 4
INDEX_CONSTRAINT_LE IndexConstraintOp = 8
INDEX_CONSTRAINT_LT IndexConstraintOp = 16
INDEX_CONSTRAINT_GE IndexConstraintOp = 32
INDEX_CONSTRAINT_MATCH IndexConstraintOp = 64
INDEX_CONSTRAINT_LIKE IndexConstraintOp = 65
INDEX_CONSTRAINT_GLOB IndexConstraintOp = 66
INDEX_CONSTRAINT_REGEXP IndexConstraintOp = 67
INDEX_CONSTRAINT_NE IndexConstraintOp = 68
INDEX_CONSTRAINT_ISNOT IndexConstraintOp = 69
INDEX_CONSTRAINT_ISNOTNULL IndexConstraintOp = 70
INDEX_CONSTRAINT_ISNULL IndexConstraintOp = 71
INDEX_CONSTRAINT_IS IndexConstraintOp = 72
INDEX_CONSTRAINT_LIMIT IndexConstraintOp = 73
INDEX_CONSTRAINT_OFFSET IndexConstraintOp = 74
INDEX_CONSTRAINT_FUNCTION IndexConstraintOp = 150
)
// IndexScanFlag is a virtual table scan flag.
//
// https://sqlite.org/c3ref/c_index_scan_unique.html
type IndexScanFlag uint32
const (
INDEX_SCAN_UNIQUE IndexScanFlag = 1
)
func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, _ ptr_t, _ int32, _, _, _ ptr_t) res_t {
return func(ctx context.Context, mod api.Module, pMod ptr_t, nArg int32, pArg, ppVTab, pzErr ptr_t) res_t {
arg := make([]reflect.Value, 1+nArg)
arg[0] = reflect.ValueOf(ctx.Value(connKey{}))
for i := range nArg {
ptr := util.Read32[ptr_t](mod, pArg+ptr_t(i)*ptrlen)
arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_SQL_LENGTH))
}
module := vtabGetHandle(ctx, mod, pMod)
val := reflect.ValueOf(module).Index(int(i)).Call(arg)
err, _ := val[1].Interface().(error)
if err == nil {
vtabPutHandle(ctx, mod, ppVTab, val[0].Interface())
}
return vtabError(ctx, mod, pzErr, _PTR_ERROR, err)
}
}
func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t {
err := vtabDelHandle(ctx, mod, pVTab)
return vtabError(ctx, mod, 0, _PTR_ERROR, err)
}
func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabDestroyer)
err := errors.Join(vtab.Destroy(), vtabDelHandle(ctx, mod, pVTab))
return vtabError(ctx, mod, 0, _PTR_ERROR, err)
}
func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo ptr_t) res_t {
var info IndexInfo
info.handle = pIdxInfo
info.c = ctx.Value(connKey{}).(*Conn)
info.load()
vtab := vtabGetHandle(ctx, mod, pVTab).(VTab)
err := vtab.BestIndex(&info)
info.save()
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, pArg, pRowID ptr_t) res_t {
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater)
rowID, err := vtab.Update(*args...)
if err == nil {
util.Write64(mod, pRowID, rowID)
}
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabRenameCallback(ctx context.Context, mod api.Module, pVTab, zNew ptr_t) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabRenamer)
err := vtab.Rename(util.ReadString(mod, zNew, _MAX_NAME))
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, zName, pxFunc ptr_t) int32 {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabOverloader)
f, op := vtab.FindFunction(int(nArg), util.ReadString(mod, zName, _MAX_NAME))
if op != 0 {
var wrapper ptr_t
wrapper = util.AddHandle(ctx, func(c Context, arg ...Value) {
defer util.DelHandle(ctx, wrapper)
f(c, arg...)
})
util.Write32(mod, pxFunc, wrapper)
}
return int32(op)
}
func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName ptr_t, mFlags uint32, pzErr ptr_t) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabChecker)
schema := util.ReadString(mod, zSchema, _MAX_NAME)
table := util.ReadString(mod, zTabName, _MAX_NAME)
err := vtab.Integrity(schema, table, int(mFlags))
// xIntegrity should return OK - even if it finds problems in the content of the virtual table.
// https://sqlite.org/vtab.html#xintegrity
vtabError(ctx, mod, pzErr, _PTR_ERROR, err)
_, code := errorCode(err, _OK)
return code
}
func vtabBeginCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn)
err := vtab.Begin()
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabSyncCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn)
err := vtab.Sync()
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabCommitCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn)
err := vtab.Commit()
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabRollbackCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn)
err := vtab.Rollback()
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabSavepointCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer)
err := vtab.Savepoint(int(id))
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabReleaseCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer)
err := vtab.Release(int(id))
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func vtabRollbackToCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer)
err := vtab.RollbackTo(int(id))
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur ptr_t) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTab)
cursor, err := vtab.Open()
if err == nil {
vtabPutHandle(ctx, mod, ppCur, cursor)
}
return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err)
}
func cursorCloseCallback(ctx context.Context, mod api.Module, pCur ptr_t) res_t {
err := vtabDelHandle(ctx, mod, pCur)
return vtabError(ctx, mod, 0, _VTAB_ERROR, err)
}
func cursorFilterCallback(ctx context.Context, mod api.Module, pCur ptr_t, idxNum int32, idxStr ptr_t, nArg int32, pArg ptr_t) res_t {
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
var idxName string
if idxStr != 0 {
idxName = util.ReadString(mod, idxStr, _MAX_LENGTH)
}
cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor)
err := cursor.Filter(int(idxNum), idxName, *args...)
return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err)
}
func cursorEOFCallback(ctx context.Context, mod api.Module, pCur ptr_t) int32 {
cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor)
if cursor.EOF() {
return 1
}
return 0
}
func cursorNextCallback(ctx context.Context, mod api.Module, pCur ptr_t) res_t {
cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor)
err := cursor.Next()
return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err)
}
func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx ptr_t, n int32) res_t {
cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor)
db := ctx.Value(connKey{}).(*Conn)
err := cursor.Column(Context{db, pCtx}, int(n))
return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err)
}
func cursorRowIDCallback(ctx context.Context, mod api.Module, pCur, pRowID ptr_t) res_t {
cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor)
rowID, err := cursor.RowID()
if err == nil {
util.Write64(mod, pRowID, rowID)
}
return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err)
}
const (
_PTR_ERROR = iota
_VTAB_ERROR
_CURSOR_ERROR
)
func vtabError(ctx context.Context, mod api.Module, ptr ptr_t, kind uint32, err error) res_t {
const zErrMsgOffset = 8
msg, code := errorCode(err, ERROR)
if msg != "" && ptr != 0 {
switch kind {
case _VTAB_ERROR:
ptr = ptr + zErrMsgOffset // zErrMsg
case _CURSOR_ERROR:
ptr = util.Read32[ptr_t](mod, ptr) + zErrMsgOffset // pVTab->zErrMsg
}
db := ctx.Value(connKey{}).(*Conn)
if ptr := util.Read32[ptr_t](mod, ptr); ptr != 0 {
db.free(ptr)
}
util.Write32(mod, ptr, db.newString(msg))
}
return code
}
func vtabGetHandle(ctx context.Context, mod api.Module, ptr ptr_t) any {
const handleOffset = 4
handle := util.Read32[ptr_t](mod, ptr-handleOffset)
return util.GetHandle(ctx, handle)
}
func vtabDelHandle(ctx context.Context, mod api.Module, ptr ptr_t) error {
const handleOffset = 4
handle := util.Read32[ptr_t](mod, ptr-handleOffset)
return util.DelHandle(ctx, handle)
}
func vtabPutHandle(ctx context.Context, mod api.Module, pptr ptr_t, val any) {
const handleOffset = 4
handle := util.AddHandle(ctx, val)
ptr := util.Read32[ptr_t](mod, pptr)
util.Write32(mod, ptr-handleOffset, handle)
}