package trdsql
import (
"fmt"
"io"
"log"
"os"
"regexp"
"strings"
"github.com/jwalton/gchalk"
"github.com/olekukonko/tablewriter"
)
// AnalyzeOpts represents the options for the operation of Analyze.
type AnalyzeOpts struct {
	// Command is the string of the execution command shown in examples.
	Command string
	// Quote is the quote character(s) that varies depending on the sql driver.
	Quote string
	// Detail enables output of detailed information (table name, types, samples).
	Detail bool
	// OutStream is the output destination.
	OutStream io.Writer
}

// Defined to wrap string styling (colors used in Analyze output).
var (
	colorTable    = gchalk.Yellow
	colorFileType = gchalk.Red
	colorCaption  = gchalk.Cyan
	colorNotes    = gchalk.Magenta
)
// NewAnalyzeOpts returns an AnalyzeOpts initialized with defaults:
// the application name as command, a backquote for quoting,
// detailed output enabled, and stdout as the destination.
func NewAnalyzeOpts() *AnalyzeOpts {
	opts := &AnalyzeOpts{}
	opts.Command = AppName
	opts.Quote = "\\`"
	opts.Detail = true
	opts.OutStream = os.Stdout
	return opts
}
// Analyze analyzes the file and outputs the table information.
// In addition, SQL execution examples are output.
func Analyze(fileName string, opts *AnalyzeOpts, readOpts *ReadOpts) error {
	w := opts.OutStream
	rOpts, fileName := GuessOpts(readOpts, fileName)
	file, err := importFileOpen(fileName)
	if err != nil {
		return err
	}
	// The table name includes the jq expression when one was given.
	tableName := fileName
	if rOpts.InJQuery != "" {
		tableName = fileName + "::" + rOpts.InJQuery
	}
	defer func() {
		if deferr := file.Close(); deferr != nil {
			log.Printf("file close:%s", deferr)
		}
	}()
	reader, err := NewReader(file, rOpts)
	if err != nil {
		return err
	}
	columnNames, err := reader.Names()
	if err != nil {
		return err
	}
	names := quoteNames(columnNames, opts.Quote)
	columnTypes, err := reader.Types()
	if err != nil {
		return err
	}
	results := getResults(reader, len(names))
	if opts.Detail {
		fmt.Fprintf(w, "The table name is %s.\n", colorTable(tableName))
		fmt.Fprintf(w, "The file type is %s.\n", colorFileType(rOpts.realFormat.String()))
		// Advice only applies when exactly one column was detected.
		// Using == 1 (not <= 1) guards the columnNames[0]/results[0][0]
		// indexing below against an empty column list.
		if len(names) == 1 && len(results) != 0 {
			additionalAdvice(w, rOpts, columnNames[0], results[0][0])
		}
		fmt.Fprintln(w, colorCaption("\nData types:"))
		typeTableRender(w, names, columnTypes)
		fmt.Fprintln(w, colorCaption("\nData samples:"))
		sampleTableRender(w, names, results)
		fmt.Fprintln(w, colorCaption("\nExamples:"))
	}
	if len(results) == 0 {
		return nil
	}
	queries := examples(tableName, names, results[0])
	for _, query := range queries {
		fmt.Fprintf(w, "%s %s\n", opts.Command, `"`+query+`"`)
	}
	return nil
}
// typeTableRender writes a two-column (name, type) ASCII table to w.
func typeTableRender(w io.Writer, names []string, columnTypes []string) {
	table := tablewriter.NewWriter(w)
	table.SetAutoFormatHeaders(false)
	table.SetHeader([]string{"column name", "type"})
	for i, name := range names {
		table.Append([]string{name, columnTypes[i]})
	}
	table.Render()
}
// sampleTableRender writes the sample rows as an ASCII table to w,
// using the column names as the header.
func sampleTableRender(w io.Writer, names []string, results [][]string) {
	table := tablewriter.NewWriter(w)
	table.SetAutoFormatHeaders(false)
	table.SetHeader(names)
	for _, record := range results {
		table.Append(record)
	}
	table.Render()
}
// additionalAdvice prints format-specific hints when the parse result
// looks suspicious (only one column was found).
func additionalAdvice(w io.Writer, rOpts *ReadOpts, name string, value string) {
	format := rOpts.realFormat
	if format == CSV {
		checkCSV(w, value)
		return
	}
	if format == JSON {
		checkJSON(w, rOpts.InJQuery, name)
	}
}
// checkCSV suggests alternatives when a CSV parse yielded one column:
// the input may actually be JSON, use another delimiter, or be LTSV.
func checkCSV(w io.Writer, value string) {
	// A lone bracket/brace strongly hints at JSON input.
	if value == "[" || value == "{" {
		fmt.Fprintln(w, colorNotes("Is it a JSON file?"))
		fmt.Fprintln(w, colorNotes("Please try again with -ijson."))
		return
	}
	fmt.Fprintln(w, colorNotes("Is the delimiter different?"))
	delimiter := " "
	// Tab takes precedence over semicolon when both occur repeatedly.
	switch {
	case strings.Count(value, "\t") > 1:
		delimiter = "\\t"
	case strings.Count(value, ";") > 1:
		delimiter = ";"
	}
	fmt.Fprintf(w, colorNotes("Please try again with -id \"%s\" or other character.\n"), delimiter)
	if strings.Contains(value, ":") {
		fmt.Fprintln(w, colorNotes("Is it a LTSV file?"))
		fmt.Fprintln(w, colorNotes("Please try again with -iltsv."))
	}
}
// checkJSON suggests drilling into nested objects with a jq expression,
// composed from any existing expression plus the detected column name.
func checkJSON(w io.Writer, jquery string, name string) {
	fmt.Fprintln(w, colorNotes("Is it for internal objects?"))
	suggestion := "." + name
	if jquery != "" {
		suggestion = jquery + suggestion
	}
	fmt.Fprintf(w, colorNotes("Please try again with -ijq \"%s\".\n"), suggestion)
}
// quoteNames returns a copy of names with each element quoted
// (when necessary) by quoted.
func quoteNames(names []string, quote string) []string {
	quotedList := make([]string, len(names))
	for i, name := range names {
		quotedList[i] = quoted(name, quote)
	}
	return quotedList
}
// noQuoteRegexp matches identifiers that never need quoting: lowercase letters, digits, underscore.
var noQuoteRegexp = regexp.MustCompile(`^[a-z0-9_]+$`)
// quoted wraps name in the given quote character(s) unless it is a
// plain lowercase identifier that is not a reserved keyword.
func quoted(name string, quote string) string {
	if noQuoteRegexp.MatchString(name) {
		if _, reserved := keywords[name]; !reserved {
			return name
		}
	}
	return quote + name + quote
}
// getResults converts the reader's pre-read rows into string matrices,
// padding/truncating each row to colNum columns.
func getResults(reader Reader, colNum int) [][]string {
	preRead := reader.PreReadRow()
	results := make([][]string, 0, len(preRead))
	for _, row := range preRead {
		cols := make([]string, colNum)
		for i, v := range row {
			cols[i] = ValString(v)
		}
		results = append(results, cols)
	}
	return results
}
// examples builds sample SQL statements (select, where, group by,
// order by) against tableName using the quoted column names and the
// first sample row.
func examples(tableName string, names []string, results []string) []string {
	columns := strings.Join(names, ", ")
	first := names[0]
	return []string{
		// #nosec G201
		fmt.Sprintf("SELECT %s FROM %s", columns, tableName),
		// #nosec G201
		fmt.Sprintf("SELECT %s FROM %s WHERE %s = '%s'", columns, tableName, first, results[0]),
		// #nosec G201
		fmt.Sprintf("SELECT %s, count(%s) FROM %s GROUP BY %s", first, first, tableName, first),
		// #nosec G201
		fmt.Sprintf("SELECT %s FROM %s ORDER BY %s LIMIT 10", columns, tableName, first),
	}
}
package cmd
import (
"compress/gzip"
"context"
"database/sql"
"flag"
"fmt"
"io"
"log"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/dsnet/compress/bzip2"
"github.com/jwalton/gchalk"
"github.com/klauspost/compress/zstd"
"github.com/noborus/trdsql"
"github.com/pierrec/lz4/v4"
"github.com/ulikunitz/xz"
)
// TableQuery is a query to use instead of TABLE.
const TableQuery = "SELECT * FROM"

// Cli wraps stdout and error output specification.
type Cli struct {
	// OutStream is the output destination.
	OutStream io.Writer
	// ErrStream is the error output destination.
	ErrStream io.Writer
}

// Debug represents a flag for detailed output.
var Debug bool

// The nilString structure represents a string
// that distinguishes between empty strings and nil.
type nilString struct {
	// str is the flag value as given on the command line.
	str string
	// valid is true once Set has been called (i.e. the flag was supplied).
	valid bool
}

// String returns the string value.
// nilString fills the flag#Value interface.
func (v *nilString) String() string {
	return v.str
}
// Set records the string and marks the value as explicitly supplied.
// nilString fills the flag#Value interface.
func (v *nilString) Set(s string) error {
	v.valid = true
	v.str = s
	return nil
}
// Run executes the main routine.
// The return value is the exit code.
//
// It parses all command-line flags, then dispatches to one of:
// version output, db list output, analyze mode, or the normal
// import → execute → export pipeline.
func (cli Cli) Run(args []string) int {
	var (
		usage   bool
		version bool
		dbList  bool

		config  string
		cDB     string
		cDriver string
		cDSN    string
		guess   bool

		queryFile string
		analyze   string
		onlySQL   string
		tableName string

		inFlag      inputFlag
		inDelimiter string
		inHeader    bool
		inSkip      int
		inPreRead   int
		inJQuery    string
		inLimitRead int
		inNull      nilString
		inRowNumber bool

		outFlag         outputFlag
		outFile         string
		outWithoutGuess bool
		outDelimiter    string
		outQuote        string
		outCompression  string
		outAllQuotes    bool
		outUseCRLF      bool
		outHeader       bool
		outNoWrap       bool
		outNull         nilString
	)

	// Flag and log output both go to the error stream.
	flags := flag.NewFlagSet(trdsql.AppName, flag.ExitOnError)
	flags.SetOutput(cli.ErrStream)
	log.SetOutput(cli.ErrStream)
	flags.Usage = func() { Usage(flags) }

	// Global options.
	flags.StringVar(&config, "config", config, "configuration file location.")
	flags.StringVar(&cDB, "db", "", "specify db name of the setting.")
	flags.BoolVar(&dbList, "dblist", false, "display db information.")
	flags.StringVar(&cDriver, "driver", "", "database driver. [ "+strings.Join(sql.Drivers(), " | ")+" ]")
	flags.StringVar(&cDSN, "dsn", "", "database driver specific data source name.")
	flags.BoolVar(&guess, "ig", true, "guess format from extension.")
	flags.StringVar(&queryFile, "q", "", "read query from the specified file.")
	flags.StringVar(&analyze, "a", "", "analyze the file and suggest SQL.")
	flags.StringVar(&onlySQL, "A", "", "analyze the file but only suggest SQL.")
	flags.StringVar(&tableName, "t", "", "read table name from the specified file.")
	flags.BoolVar(&usage, "help", false, "display usage information.")
	flags.BoolVar(&version, "version", false, "display version information.")
	flags.BoolVar(&Debug, "debug", false, "debug print.")

	// Input options.
	flags.StringVar(&inDelimiter, "id", ",", "field delimiter for input.")
	flags.BoolVar(&inHeader, "ih", false, "the first line is interpreted as column names(CSV only).")
	flags.IntVar(&inSkip, "is", 0, "skip header row.")
	flags.IntVar(&inPreRead, "ir", 1, "number of rows to preread.")
	flags.IntVar(&inLimitRead, "ilr", 0, "limited number of rows to read.")
	flags.StringVar(&inJQuery, "ijq", "", "jq expression string for input(JSON/JSONL only).")
	flags.Var(&inNull, "inull", "value(string) to convert to null on input.")
	flags.BoolVar(&inRowNumber, "inum", false, "add row number column.")

	// Input format flags.
	flags.BoolVar(&inFlag.CSV, "icsv", false, "CSV format for input.")
	flags.BoolVar(&inFlag.LTSV, "iltsv", false, "LTSV format for input.")
	flags.BoolVar(&inFlag.JSON, "ijson", false, "JSON format for input.")
	flags.BoolVar(&inFlag.YAML, "iyaml", false, "YAML format for input.")
	flags.BoolVar(&inFlag.TBLN, "itbln", false, "TBLN format for input.")
	flags.BoolVar(&inFlag.WIDTH, "iwidth", false, "width specification format for input.")
	flags.BoolVar(&inFlag.TEXT, "itext", false, "text format for input.")

	// Output options.
	flags.StringVar(&outFile, "out", "", "output file name.")
	flags.BoolVar(&outWithoutGuess, "out-without-guess", false, "output without guessing (when using -out).")
	flags.StringVar(&outDelimiter, "od", ",", "field delimiter for output.")
	flags.StringVar(&outQuote, "oq", "\"", "quote character for output.")
	flags.BoolVar(&outAllQuotes, "oaq", false, "enclose all fields in quotes for output.")
	flags.BoolVar(&outUseCRLF, "ocrlf", false, "use CRLF for output. End each output line with '\\r\\n' instead of '\\n'.")
	flags.BoolVar(&outNoWrap, "onowrap", false, "do not wrap long lines(at/md only).")
	flags.BoolVar(&outHeader, "oh", false, "output column name as header.")
	flags.StringVar(&outCompression, "oz", "", "output compression format. [ gz | bz2 | zstd | lz4 | xz ]")
	flags.Var(&outNull, "onull", "value(string) to convert from null on output.")

	// Output format flags.
	flags.BoolVar(&outFlag.CSV, "ocsv", false, "CSV format for output.")
	flags.BoolVar(&outFlag.LTSV, "oltsv", false, "LTSV format for output.")
	flags.BoolVar(&outFlag.AT, "oat", false, "ASCII Table format for output.")
	flags.BoolVar(&outFlag.MD, "omd", false, "Markdown format for output.")
	flags.BoolVar(&outFlag.VF, "ovf", false, "Vertical format for output.")
	flags.BoolVar(&outFlag.RAW, "oraw", false, "Raw format for output.")
	flags.BoolVar(&outFlag.JSON, "ojson", false, "JSON format for output.")
	flags.BoolVar(&outFlag.TBLN, "otbln", false, "TBLN format for output.")
	flags.BoolVar(&outFlag.JSONL, "ojsonl", false, "JSON lines format for output.")
	flags.BoolVar(&outFlag.YAML, "oyaml", false, "YAML format for output.")

	if err := flags.Parse(args[1:]); err != nil {
		log.Printf("ERROR: %s", err)
		return 1
	}
	if version {
		fmt.Fprintf(cli.OutStream, "%s version %s\n", trdsql.AppName, trdsql.Version)
		return 0
	}
	if Debug {
		trdsql.EnableDebug()
	}

	// MultipleQueries is enabled by default.
	trdsql.EnableMultipleQueries()

	// A missing config file is only an error when -config was given explicitly.
	cfgFile := configOpen(config)
	cfg, err := loadConfig(cfgFile)
	if err != nil && config != "" {
		log.Printf("ERROR: [%s]%s", config, err)
		return 1
	}
	if dbList {
		printDBList(cli.OutStream, cfg)
		return 0
	}

	driver, dsn := getDB(cfg, cDB, cDriver, cDSN)

	// Analyze mode (-a / -A) short-circuits the normal pipeline.
	if analyze != "" || onlySQL != "" {
		opts := trdsql.NewAnalyzeOpts()
		opts.OutStream = cli.OutStream
		opts = quoteOpts(opts, driver)
		if onlySQL != "" {
			analyze = onlySQL
			opts.Detail = false
		}
		opts = optsCommand(opts, os.Args)
		// With a header row, one extra row must be preread to get samples.
		if inHeader && inPreRead == 1 {
			inPreRead = 2
		}
		readOpts := trdsql.NewReadOpts(
			trdsql.InFormat(inputFormat(inFlag)),
			trdsql.InDelimiter(inDelimiter),
			trdsql.InHeader(inHeader),
			trdsql.InSkip(inSkip),
			trdsql.InPreRead(inPreRead),
			trdsql.InJQ(inJQuery),
		)
		if err = trdsql.Analyze(analyze, opts, readOpts); err != nil {
			log.Printf("ERROR: %s", err)
			return 1
		}
		return 0
	}

	query, err := getQuery(flags.Args(), tableName, queryFile)
	if err != nil {
		log.Printf("ERROR: %s", err)
		return 1
	}
	if usage || (len(query) == 0) {
		Usage(flags)
		return 2
	}

	// -ilr caps the rows read; account for skipped and header rows.
	preRead := inPreRead
	limitRead := false
	if inLimitRead > 0 {
		limitRead = true
		preRead = inLimitRead
		if inSkip > 0 {
			preRead += inSkip
		}
		if inHeader {
			preRead++
		}
	}

	importer := trdsql.NewImporter(
		trdsql.InFormat(inputFormat(inFlag)),
		trdsql.InDelimiter(inDelimiter),
		trdsql.InHeader(inHeader),
		trdsql.InSkip(inSkip),
		trdsql.InPreRead(preRead),
		trdsql.InLimitRead(limitRead),
		trdsql.InJQ(inJQuery),
		trdsql.InNeedNULL(inNull.valid),
		trdsql.InNULL(inNull.str),
		trdsql.InRowNumber(inRowNumber),
	)

	writer := cli.OutStream
	if outFile != "" {
		writer, err = os.Create(outFile)
		if err != nil {
			log.Printf("%s", err)
			return 1
		}
	}

	// Guess the output format/compression from the file name unless disabled.
	outFormat := outputFormat(outFlag)
	if outFormat == trdsql.GUESS {
		if outWithoutGuess {
			outFormat = trdsql.CSV
		} else {
			outFormat = outGuessFormat(outFile)
		}
	}
	if outCompression == "" && !outWithoutGuess {
		outCompression = outGuessCompression(outFile)
	}
	writer, err = compressionWriter(writer, outCompression)
	if err != nil {
		log.Printf("%s", err)
		return 1
	}

	w := trdsql.NewWriter(
		trdsql.OutFormat(outFormat),
		trdsql.OutDelimiter(outDelimiter),
		trdsql.OutQuote(outQuote),
		trdsql.OutAllQuotes(outAllQuotes),
		trdsql.OutUseCRLF(outUseCRLF),
		trdsql.OutHeader(outHeader),
		trdsql.OutNoWrap(outNoWrap),
		trdsql.OutNeedNULL(outNull.valid),
		trdsql.OutNULL(outNull.str),
		trdsql.OutStream(writer),
		trdsql.ErrStream(cli.ErrStream),
	)
	exporter := trdsql.NewExporter(w)
	trd := trdsql.NewTRDSQL(importer, exporter)
	if driver != "" {
		trd.Driver = driver
	}
	if dsn != "" {
		trd.Dsn = dsn
	}
	ctx := context.Background()
	if err = trd.ExecContext(ctx, query); err != nil {
		log.Printf("%s", err)
		return 1
	}

	// Close the output (flushes compression writers, closes the file).
	if wc, ok := writer.(io.Closer); ok {
		err = wc.Close()
		if err != nil {
			log.Printf("%s", err)
			return 1
		}
	}
	return 0
}
// Usage outputs usage information, grouping the registered flags into
// global options, input/output options and input/output format flags.
func Usage(flags *flag.FlagSet) {
	out := flags.Output()
	bold := gchalk.Bold
	fmt.Fprintf(out, "%s - Execute SQL queries on CSV, LTSV, JSON, YAML and TBLN.\n\n", trdsql.AppName)
	fmt.Fprintf(out, "%s\n", bold("Usage"))
	fmt.Fprintf(out, "\t%s [OPTIONS] [SQL(SELECT...)]\n\n", trdsql.AppName)

	var global, input, inputF, output, outputF []string
	flags.VisitAll(func(f *flag.Flag) {
		line := usageFlag(f)
		// Flags starting with 'i'/'o' are input/output; format flags
		// are separated from plain options.
		switch f.Name[0] {
		case 'i':
			if isInFormat(f.Name) {
				inputF = append(inputF, line)
			} else {
				input = append(input, line)
			}
		case 'o':
			if isOutFormat(f.Name) {
				outputF = append(outputF, line)
			} else {
				output = append(output, line)
			}
		default:
			global = append(global, line)
		}
	})

	printSection := func(title string, lines []string) {
		fmt.Fprintf(out, "\n%s\n", bold(title))
		for _, line := range lines {
			fmt.Fprint(out, line, "\n")
		}
	}
	// The first section has no leading blank line.
	fmt.Fprintf(out, "%s\n", bold("Options:"))
	for _, line := range global {
		fmt.Fprint(out, line, "\n")
	}
	printSection("Input Formats:", inputF)
	printSection("Input options:", input)
	printSection("Output Formats:", outputF)
	printSection("Output options:", output)

	fmt.Fprintf(out, "\n%s\n", bold("Examples:"))
	fmt.Fprintf(out, " $ trdsql \"SELECT c1,c2 FROM test.csv\"\n")
	fmt.Fprintf(out, " $ trdsql -oltsv \"SELECT c1,c2 FROM test.json::.items\"\n")
	fmt.Fprintf(out, " $ cat test.csv | trdsql -icsv -oltsv \"SELECT c1,c2 FROM -\"\n")
}
func usageFlag(f *flag.Flag) string {
vType, usage := flag.UnquoteUsage(f)
name := f.Name
if vType != "" {
name += " " + vType
}
s := fmt.Sprintf(" -%-18s %s", name, usage)
if f.DefValue == "0" || f.DefValue == "1" {
s += fmt.Sprintf(" (default %v)", f.DefValue)
} else if f.DefValue != "" && f.DefValue != "false" {
s += fmt.Sprintf(" (default %q)", f.DefValue)
}
return s
}
// printDBList writes "name:driver" for each configured database.
// Note: map iteration order is unspecified, so output order varies.
func printDBList(w io.Writer, cfg *config) {
	for name, d := range cfg.Database {
		fmt.Fprintf(w, "%s:%s\n", name, d.Driver)
	}
}
// quoteOpts adjusts the identifier quote for the driver:
// PostgreSQL uses double quotes instead of the default backquote.
func quoteOpts(opts *trdsql.AnalyzeOpts, driver string) *trdsql.AnalyzeOpts {
	if driver == "postgres" {
		opts.Quote = `\"`
	}
	return opts
}
// optsCommand reconstructs the invoked command line for display,
// dropping the analyze flags (-a/-A) and -ijq together with their
// values, and quoting arguments that need it.
func optsCommand(opts *trdsql.AnalyzeOpts, args []string) *trdsql.AnalyzeOpts {
	var b strings.Builder
	b.WriteString(args[0])
	skipValue := false
	for _, arg := range args[1:] {
		if skipValue {
			// This is the value of a flag omitted below.
			skipValue = false
			continue
		}
		switch arg {
		case "-a", "-A", "-ijq":
			skipValue = true
			continue
		}
		if len(arg) <= 1 || arg[0] != '-' {
			arg = quotedArg(arg)
		}
		b.WriteString(" ")
		b.WriteString(arg)
	}
	opts.Command = b.String()
	return opts
}
// trimQuery strips surrounding whitespace and any trailing semicolons
// from the query string.
func trimQuery(query string) string {
	trimmed := strings.TrimSpace(query)
	return strings.TrimRight(trimmed, ";")
}
func getQuery(args []string, tableName string, queryFile string) (string, error) {
if tableName != "" {
var query strings.Builder
query.WriteString(TableQuery)
query.WriteString(" ")
query.WriteString(tableName)
return trimQuery(query.String()), nil
}
if queryFile == "" {
return trimQuery(strings.Join(args, " ")), nil
}
sqlByte, err := os.ReadFile(queryFile)
if err != nil {
return "", err
}
return trimQuery(string(sqlByte)), nil
}
// getDB resolves the (driver, dsn) pair from flags and configuration.
// Precedence: explicit -driver, then explicit -dsn, then the named
// config entry; an unknown db name logs an error and yields defaults.
func getDB(cfg *config, cDB string, cDriver string, cDSN string) (string, string) {
	if cDB == "" {
		cDB = cfg.Db
	}
	if Debug {
		// '>' marks the selected entry.
		for od, odb := range cfg.Database {
			if cDB == od {
				log.Printf(">[driver: %s:%s:%s]", od, odb.Driver, odb.Dsn)
			} else {
				log.Printf(" [driver: %s:%s:%s]", od, odb.Driver, odb.Dsn)
			}
		}
	}
	if cDriver != "" {
		return cDriver, cDSN
	}
	if cDSN != "" {
		return "", cDSN
	}
	if cDB != "" {
		// Single lookup instead of three map accesses.
		if d, ok := cfg.Database[cDB]; ok && d.Driver != "" {
			return d.Driver, d.Dsn
		}
		log.Printf("ERROR: db[%s] not found", cDB)
	}
	return "", ""
}
// argQuote matches arguments safe to print without quoting.
var argQuote = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// quotedArg wraps arg in double quotes unless it is a plain
// alphanumeric/underscore token.
func quotedArg(arg string) string {
	if !argQuote.MatchString(arg) {
		return `"` + arg + `"`
	}
	return arg
}
// inputFlag represents the format of the input.
// Each field corresponds to one -i* format flag.
type inputFlag struct {
	CSV   bool
	LTSV  bool
	JSON  bool
	YAML  bool
	TBLN  bool
	WIDTH bool
	TEXT  bool
}
// inputFormat returns the input format from the flag set.
// The first set flag wins, in this fixed priority order;
// with none set, the format is guessed.
func inputFormat(i inputFlag) trdsql.Format {
	ordered := []struct {
		set    bool
		format trdsql.Format
	}{
		{i.CSV, trdsql.CSV},
		{i.LTSV, trdsql.LTSV},
		{i.JSON, trdsql.JSON},
		{i.YAML, trdsql.YAML},
		{i.TBLN, trdsql.TBLN},
		{i.WIDTH, trdsql.WIDTH},
		{i.TEXT, trdsql.TEXT},
	}
	for _, entry := range ordered {
		if entry.set {
			return entry.format
		}
	}
	return trdsql.GUESS
}
// isInFormat reports whether name is an input-format flag name.
func isInFormat(name string) bool {
	formatFlags := []string{"ig", "icsv", "iltsv", "ijson", "iyaml", "itbln", "iwidth", "itext"}
	for _, f := range formatFlags {
		if name == f {
			return true
		}
	}
	return false
}
// outputFlag represents the format of the output.
// Each field corresponds to one -o* format flag.
type outputFlag struct {
	CSV   bool
	LTSV  bool
	JSON  bool
	JSONL bool
	YAML  bool
	TBLN  bool
	AT    bool
	MD    bool
	VF    bool
	RAW   bool
}
// outputFormat returns the output format from the flag set.
// The first set flag wins, in this fixed priority order;
// with none set, the format is guessed.
func outputFormat(o outputFlag) trdsql.Format {
	ordered := []struct {
		set    bool
		format trdsql.Format
	}{
		{o.LTSV, trdsql.LTSV},
		{o.JSON, trdsql.JSON},
		{o.RAW, trdsql.RAW},
		{o.MD, trdsql.MD},
		{o.AT, trdsql.AT},
		{o.VF, trdsql.VF},
		{o.TBLN, trdsql.TBLN},
		{o.JSONL, trdsql.JSONL},
		{o.YAML, trdsql.YAML},
		{o.CSV, trdsql.CSV},
	}
	for _, entry := range ordered {
		if entry.set {
			return entry.format
		}
	}
	return trdsql.GUESS
}
// isOutFormat reports whether name is an output-format flag name.
func isOutFormat(name string) bool {
	formatFlags := []string{"ocsv", "oltsv", "ojson", "ojsonl", "oyaml", "otbln", "oat", "omd", "ovf", "oraw"}
	for _, f := range formatFlags {
		if name == f {
			return true
		}
	}
	return false
}
// outGuessFormat guesses the output format from the file extension,
// stripping extensions (e.g. compression suffixes) until one maps to a
// known format. A name with no recognizable extension defaults to CSV.
func outGuessFormat(fileName string) trdsql.Format {
	for {
		dotExt := filepath.Ext(fileName)
		if dotExt == "" {
			return trdsql.CSV
		}
		ext := strings.ToUpper(strings.TrimLeft(dotExt, "."))
		if format := trdsql.OutputFormat(ext); format != trdsql.GUESS {
			return format
		}
		fileName = strings.TrimSuffix(fileName, dotExt)
	}
}
// outGuessCompression maps the file extension (case-insensitive) to a
// compression format name; unknown extensions yield "".
func outGuessCompression(fileName string) string {
	ext := strings.ToLower(strings.TrimLeft(filepath.Ext(fileName), "."))
	switch ext {
	case "gz":
		return "gzip"
	case "bz2":
		return "bzip2"
	case "zst":
		return "zstd"
	case "lz4", "xz":
		return ext
	}
	return ""
}
// compressionWriter wraps w in a compressing writer selected by the
// (case-insensitive) compression name. An empty or unknown name
// returns w unchanged. The caller is responsible for closing the
// returned writer so compressed streams are flushed.
func compressionWriter(w io.Writer, compression string) (io.Writer, error) {
	switch strings.ToLower(compression) {
	case "gz", "gzip":
		return gzip.NewWriter(w), nil
	case "bz2", "bzip2":
		return bzip2.NewWriter(w, &bzip2.WriterConfig{})
	case "zst", "zstd":
		return zstd.NewWriter(w)
	case "lz4":
		return lz4.NewWriter(w), nil
	case "xz":
		return xz.NewWriter(w)
	default:
		return w, nil
	}
}
package cmd
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"runtime"
"github.com/noborus/trdsql"
)
// ErrNoFile is returned when there is no configuration file.
var ErrNoFile = errors.New("no file")

// database holds one configured connection: driver name and DSN.
type database struct {
	Driver string `json:"driver"`
	Dsn    string `json:"dsn"`
}

// config is the JSON configuration: the default db name and the
// named database entries.
type config struct {
	Db       string              `json:"db"`
	Database map[string]database `json:"database"`
}
// configOpen opens the configuration file. An explicit path wins;
// otherwise the platform default location is used (APPDATA on
// Windows, ~/.config elsewhere). Returns nil when the file cannot
// be opened.
func configOpen(config string) io.Reader {
	fileName := config
	if fileName == "" {
		if runtime.GOOS == "windows" {
			fileName = filepath.Join(os.Getenv("APPDATA"), trdsql.AppName, "config.json")
		} else {
			fileName = filepath.Join(os.Getenv("HOME"), ".config", trdsql.AppName, "config.json")
		}
	}
	cfg, err := os.Open(fileName)
	if err != nil {
		if Debug {
			log.Printf("configOpen: %s", err.Error())
		}
		return nil
	}
	if Debug {
		log.Printf("config found: %s", fileName)
	}
	return cfg
}
// loadConfig decodes the JSON configuration from conf.
// A nil reader yields a zero config together with ErrNoFile.
func loadConfig(conf io.Reader) (*config, error) {
	cfg := &config{}
	if conf == nil {
		return cfg, ErrNoFile
	}
	if err := json.NewDecoder(conf).Decode(cfg); err != nil {
		return cfg, fmt.Errorf("config error: %w", err)
	}
	return cfg, nil
}
package main
import (
"os"
"github.com/noborus/trdsql/cmd"
)
// main wires stdout/stderr into the CLI and exits with its return code.
func main() {
	cli := cmd.Cli{OutStream: os.Stdout, ErrStream: os.Stderr}
	os.Exit(cli.Run(os.Args))
}
package trdsql
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"log"
"strings"
)
var (
	// ErrNoTransaction is returned if SQL is executed when a transaction has not started.
	// SQL must be executed within a transaction.
	ErrNoTransaction = errors.New("transaction has not been started")
	// ErrNilReader is returned by Set reader of the specified file is nil error.
	ErrNilReader = errors.New("nil reader")
	// ErrInvalidNames is returned by Set if invalid names (number of columns is 0).
	ErrInvalidNames = errors.New("invalid names")
	// ErrInvalidTypes is returned by Set if invalid column types (does not match the number of column names).
	ErrInvalidTypes = errors.New("invalid types")
	// ErrNoStatement is returned by no SQL statement.
	ErrNoStatement = errors.New("no SQL statement")
)

// DB represents database information.
type DB struct {
	// driver holds the sql driver as a string.
	driver string
	// dsn holds dsn of sql as a character string.
	dsn string
	// quote is the quote character(s) that varies depending on the sql driver.
	// PostgreSQL is ("), sqlite3 and mysql is (`).
	quote string
	// maxBulk is the maximum number of bundles for bulk insert.
	// The number of columns x rows is less than maxBulk.
	maxBulk int
	// *sql.DB represents the database connection.
	*sql.DB
	// Tx represents a database transaction.
	Tx *sql.Tx
	// importCount represents the number that is incremented to identify imported files.
	importCount int
}

// Disconnect disconnects the database by closing the underlying *sql.DB.
func (db *DB) Disconnect() error {
	return db.Close()
}

// CreateTable creates a (temporary) table in the database.
// The arguments are the table name, column names, column types, and temporary flag.
// It delegates to CreateTableContext with a background context.
func (db *DB) CreateTable(tableName string, columnNames []string, columnTypes []string, isTemporary bool) error {
	return db.CreateTableContext(context.Background(), tableName, columnNames, columnTypes, isTemporary)
}
// CreateTableContext creates a (temporary) table in the database.
// The arguments are the table name, column names, column types, and
// temporary flag. It requires an open transaction, at least one column,
// and matching name/type counts.
func (db *DB) CreateTableContext(ctx context.Context, tableName string, columnNames []string, columnTypes []string, isTemporary bool) error {
	switch {
	case db.Tx == nil:
		return ErrNoTransaction
	case len(columnNames) == 0:
		return ErrInvalidNames
	case len(columnNames) != len(columnTypes):
		return ErrInvalidTypes
	}
	query := db.queryCreateTable(tableName, columnNames, columnTypes, isTemporary)
	debug.Printf(query)
	_, err := db.Tx.ExecContext(ctx, query)
	return err
}
// queryCreateTable builds the CREATE [TEMPORARY] TABLE statement with
// quoted column names. Callers validate that at least one column exists.
func (db *DB) queryCreateTable(tableName string, columnNames []string, columnTypes []string, isTemporary bool) string {
	defs := make([]string, len(columnNames))
	for i, name := range columnNames {
		defs[i] = db.QuotedName(name) + " " + columnTypes[i]
	}
	stmt := "CREATE TABLE "
	if isTemporary {
		stmt = "CREATE TEMPORARY TABLE "
	}
	return stmt + tableName + " ( " + strings.Join(defs, ", ") + " );"
}
// importTable represents the table data to be imported.
type importTable struct {
	// tableName is the destination table.
	tableName string
	// columns are the quoted column names.
	columns []string
	// row is the reusable buffer one row is read into.
	row []any
	// maxCap caps the number of bound values per bulk INSERT.
	maxCap int
	// lastCount is the row count the current prepared statement was built for.
	lastCount int
	// count is the number of rows accumulated for the current statement.
	count int
}

// Import imports data into a table.
// It delegates to ImportContext with a background context.
func (db *DB) Import(tableName string, columnNames []string, reader Reader) error {
	return db.ImportContext(context.Background(), tableName, columnNames, reader)
}
// ImportContext imports data into a table. It requires an open
// transaction and a non-nil reader; PostgreSQL takes the COPY path,
// every other driver uses bulk INSERT.
func (db *DB) ImportContext(ctx context.Context, tableName string, columnNames []string, reader Reader) error {
	switch {
	case db.Tx == nil:
		return ErrNoTransaction
	case reader == nil:
		return ErrNilReader
	}
	quotedColumns := make([]string, len(columnNames))
	for i, name := range columnNames {
		quotedColumns[i] = db.QuotedName(name)
	}
	table := &importTable{
		tableName: tableName,
		columns:   quotedColumns,
		row:       make([]any, len(columnNames)),
		lastCount: 0,
		count:     0,
	}
	if db.driver == "postgres" {
		return db.copyImport(ctx, table, reader)
	}
	return db.insertImport(ctx, table, reader)
}
// copyImport adds rows to a table with the COPY clause (PostgreSQL only).
// It streams the pre-read rows first, then the remaining rows from the
// reader; the final parameterless ExecContext terminates the COPY.
func (db *DB) copyImport(ctx context.Context, table *importTable, reader Reader) error {
	query := queryCopy(table)
	debug.Printf(query)
	stmt, err := db.Tx.PrepareContext(ctx, query)
	if err != nil {
		return fmt.Errorf("COPY prepare: %w", err)
	}
	defer db.stmtClose(stmt)
	// Rows already consumed while guessing the format/types.
	preReadRows := reader.PreReadRow()
	for _, row := range preReadRows {
		if row == nil {
			break
		}
		if _, err = stmt.ExecContext(ctx, row...); err != nil {
			return err
		}
	}
	for {
		table.row, err = reader.ReadRow(table.row)
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return fmt.Errorf("COPY read: %w", err)
		}
		// Skip when empty read.
		if len(table.row) == 0 {
			continue
		}
		if _, err = stmt.ExecContext(ctx, table.row...); err != nil {
			return err
		}
	}
	// Flush/terminate the COPY stream.
	_, err = stmt.ExecContext(ctx)
	return err
}
// queryCopy constructs a SQL COPY statement for the table's columns.
func queryCopy(table *importTable) string {
	columns := strings.Join(table.columns, ", ")
	return "COPY " + table.tableName + " (" + columns + ") FROM STDIN;"
}
// insertImport adds rows to a table with an INSERT clause, inserting
// multiple rows per statement by bulk insert. The prepared statement is
// re-prepared (via bulkStmtOpen) whenever the batch row count changes.
func (db *DB) insertImport(ctx context.Context, table *importTable, reader Reader) error {
	var err error
	var stmt *sql.Stmt
	// The deferred close must observe the *final* value of stmt.
	// A plain `defer db.stmtClose(stmt)` evaluates stmt now (nil),
	// which would always close nil and leak the last prepared statement.
	defer func() {
		db.stmtClose(stmt)
	}()
	// maxCap is the largest multiple of the row width not exceeding maxBulk
	// (at least one full row).
	if len(table.row) > db.maxBulk {
		table.maxCap = len(table.row)
	} else {
		table.maxCap = (db.maxBulk / len(table.row)) * len(table.row)
	}
	bulk := make([]any, 0, table.maxCap)
	preRows := reader.PreReadRow()
	preRowNum := len(preRows)
	preCount := 0
	for eof := false; !eof; {
		if preCount < preRowNum {
			// PreRead: drain rows already consumed during format guessing.
			for preCount < preRowNum {
				row := preRows[preCount]
				bulk = append(bulk, row...)
				table.count++
				preCount++
				if (table.count * len(table.row)) > table.maxCap {
					break
				}
			}
		} else {
			// Read: pull rows from the reader until the batch is full or EOF.
			bulk, err = bulkPush(ctx, table, reader, bulk)
			if err != nil {
				if !errors.Is(err, io.EOF) {
					return fmt.Errorf("bulk read: %w", err)
				}
				eof = true
				if len(bulk) == 0 {
					return nil
				}
			}
		}
		stmt, err = db.bulkStmtOpen(ctx, table, stmt)
		if err != nil {
			return err
		}
		if _, err := stmt.ExecContext(ctx, bulk...); err != nil {
			return err
		}
		bulk = bulk[:0]
		table.count = 0
	}
	return nil
}
// stmtClose closes a prepared statement, tolerating nil and logging
// (not returning) any close error.
func (db *DB) stmtClose(stmt *sql.Stmt) {
	if stmt == nil {
		return
	}
	if closeErr := stmt.Close(); closeErr != nil {
		log.Printf("ERROR: stmtClose:%s", closeErr)
	}
}
// bulkPush appends rows from the reader to the batch until the batch
// would exceed table.maxCap bound values. It propagates read errors
// (including io.EOF) and honors context cancellation between rows.
func bulkPush(ctx context.Context, table *importTable, input Reader, bulk []any) ([]any, error) {
	rowWidth := len(table.row)
	for table.count*rowWidth < table.maxCap {
		row, err := input.ReadRow(table.row)
		if err != nil {
			return bulk, err
		}
		// Skip when empty read.
		if len(row) == 0 {
			continue
		}
		bulk = append(bulk, row...)
		table.count++
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}
	}
	return bulk, nil
}
// bulkStmtOpen returns a prepared INSERT statement for the current
// batch size, reusing stmt when the row count is unchanged and
// re-preparing (closing the old statement) otherwise.
func (db *DB) bulkStmtOpen(ctx context.Context, table *importTable, stmt *sql.Stmt) (*sql.Stmt, error) {
	if table.lastCount == table.count {
		return stmt, nil
	}
	db.stmtClose(stmt)
	prepared, err := db.insertPrepare(ctx, table)
	if err != nil {
		return nil, err
	}
	table.lastCount = table.count
	return prepared, nil
}
// insertPrepare builds the bulk INSERT statement for the current batch
// size and prepares it on the transaction, wrapping any prepare error
// with the offending query.
func (db *DB) insertPrepare(ctx context.Context, table *importTable) (*sql.Stmt, error) {
	query := queryInsert(table)
	debug.Printf(query)
	stmt, err := db.Tx.PrepareContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("INSERT Prepare: %s:%w", query, err)
	}
	return stmt, nil
}
// queryInsert constructs a bulk SQL INSERT statement with one "(?,...)"
// placeholder group per accumulated row (at least one group).
func queryInsert(table *importTable) string {
	group := "(?" + strings.Repeat(",?", len(table.columns)-1) + ")"
	var buf strings.Builder
	buf.WriteString("INSERT INTO ")
	buf.WriteString(table.tableName)
	buf.WriteString(" (")
	buf.WriteString(strings.Join(table.columns, ", "))
	buf.WriteString(") VALUES ")
	buf.WriteString(group)
	for i := 1; i < table.count; i++ {
		buf.WriteString(",")
		buf.WriteString(group)
	}
	buf.WriteString(";")
	return buf.String()
}
// QuotedName returns the name wrapped in the driver's quote
// character(s). Returns as is if already quoted (only the first
// character is checked) or empty.
func (db *DB) QuotedName(orgName string) string {
	if orgName == "" {
		return ""
	}
	if orgName[0] == db.quote[0] {
		return orgName
	}
	return db.quote + orgName + db.quote
}
// Select executes SQL select statements.
// It delegates to SelectContext with a background context.
func (db *DB) Select(query string) (*sql.Rows, error) {
	return db.SelectContext(context.Background(), query)
}

// SelectContext executes SQL select statements with context.
// SelectContext is a wrapper for QueryContext; errors are wrapped with
// the offending query for diagnosis.
func (db *DB) SelectContext(ctx context.Context, query string) (*sql.Rows, error) {
	rows, err := db.Tx.QueryContext(ctx, query)
	if err != nil {
		return rows, fmt.Errorf("%w [%s]", err, query)
	}
	return rows, nil
}

// OtherExecContext executes a non-SELECT SQL statement within the
// transaction, wrapping any error with the offending query.
func (db *DB) OtherExecContext(ctx context.Context, query string) error {
	_, err := db.Tx.ExecContext(ctx, query)
	if err != nil {
		return fmt.Errorf("%w [%s]", err, query)
	}
	return nil
}
//go:build cgo
package trdsql
import (
"database/sql"
// MySQL driver.
_ "github.com/go-sql-driver/mysql"
// PostgreSQL driver.
_ "github.com/lib/pq"
// SQLite3 driver.
_ "github.com/mattn/go-sqlite3"
// SQlite3 extension library.
sqlite3_stdlib "github.com/multiprocessio/go-sqlite3-stdlib"
)
// DefaultDriver is the default SQL driver name ("sqlite3").
// NOTE(review): presumably consumed where no driver is specified — confirm at call sites.
var DefaultDriver = "sqlite3"

func init() {
	// Enable sqlite3 extensions.
	// It can be used by setting the driver to "sqlite3_ext".
	sqlite3_stdlib.Register("sqlite3_ext")
}
// Connect connects to the database.
// Currently supported drivers are sqlite3, mysql, postgres.
// Set quote character and maxBulk depending on the driver type.
func Connect(driver, dsn string) (*DB, error) {
	sqlDB, err := sql.Open(driver, dsn)
	if err != nil {
		return nil, err
	}
	db := &DB{
		DB:     sqlDB,
		driver: driver,
		dsn:    dsn,
	}
	debug.Printf("driver: %s, dsn: %s", driver, dsn)
	switch driver {
	// sqlite variants and mysql share the backquote and bulk limit.
	case "sqlite3", "sqlite3_ext", "sqlite", "mysql":
		db.quote = "`"
		db.maxBulk = 1000
	default:
		// postgres (and unknown drivers) use the SQL standard double quote.
		db.quote = `"`
	}
	return db, nil
}
package trdsql
import "log"
// debugT is a type of debug flag.
type debugT bool
// debug is a flag for detailed output.
var debug = debugT(false)
// EnableDebug is enable verbose output for debug.
func EnableDebug() {
debug = true
}
// Printf forwards to log.Printf, but only when debugging is enabled.
func (d debugT) Printf(format string, args ...any) {
	if !d {
		return
	}
	log.Printf(format, args...)
}
package trdsql
import (
"context"
"database/sql"
"log"
"strings"
"github.com/noborus/sqlss"
)
// Exporter is the interface for processing query results.
// Exporter executes SQL and outputs to Writer.
type Exporter interface {
	// Export runs the SQL with a background context.
	Export(db *DB, sql string) error
	// ExportContext runs the SQL with the given context.
	ExportContext(ctx context.Context, db *DB, sql string) error
}
// WriteFormat represents a structure that satisfies Exporter.
type WriteFormat struct {
	Writer
	columns []string // column names of the current result set
	types   []string // database type names for columns
	multi   bool     // true when executing multiple SQL statements
}
// NewExporter returns trdsql default Exporter.
func NewExporter(writer Writer) *WriteFormat {
	e := &WriteFormat{multi: false}
	e.Writer = writer
	return e
}
// Export is execute SQL(Select) and the result is written out by the writer.
// Export is called from Exec.
func (e *WriteFormat) Export(db *DB, sql string) error {
	return e.ExportContext(context.Background(), db, sql)
}
// ExportContext is execute SQL(Select) and the result is written out by the writer.
// ExportContext is called from ExecContext.
// The SQL string may contain multiple statements; they are split and run in order.
func (e *WriteFormat) ExportContext(ctx context.Context, db *DB, sqlQuery string) error {
	queries := sqlss.SplitQueries(sqlQuery)
	// Single statement: execute as-is.
	// (Fixes the `!multi` guard, which referenced an undefined identifier —
	// the field is e.multi — and would have short-circuited multi-statement mode.)
	if len(queries) == 1 {
		return e.exportContext(ctx, db, sqlQuery)
	}
	// Multiple statements: set multi mode so statements with no result set are skipped.
	e.multi = true
	for _, query := range queries {
		if err := e.exportContext(ctx, db, query); err != nil {
			return err
		}
	}
	return nil
}
// exportContext runs a single SQL statement within the open transaction and
// streams the result set through the Writer via write.
// Non-SELECT statements (per isExecContext) are executed without producing output.
func (e *WriteFormat) exportContext(ctx context.Context, db *DB, query string) error {
	if db.Tx == nil {
		return ErrNoTransaction
	}
	query = strings.TrimSpace(query)
	if query == "" {
		return ErrNoStatement
	}
	debug.Printf(query)
	// Statements that return no rows must go through ExecContext.
	if db.isExecContext(query) {
		return db.OtherExecContext(ctx, query)
	}
	rows, err := db.SelectContext(ctx, query)
	if err != nil {
		return err
	}
	columns, err := rows.Columns()
	if err != nil {
		return err
	}
	e.columns = columns
	defer func() {
		if err = rows.Close(); err != nil {
			log.Printf("ERROR: close:%s", err)
		}
	}()
	// No data is not output for multiple queries.
	if e.multi && len(e.columns) == 0 {
		return nil
	}
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return err
	}
	types := make([]string, len(columns))
	for i, ct := range columnTypes {
		types[i] = ct.DatabaseTypeName()
	}
	e.types = types
	return e.write(ctx, rows)
}
// write streams every row of the result set through the Writer:
// PreWrite once with names/types, WriteRow per row, PostWrite at the end.
// It honors ctx cancellation between rows.
func (e *WriteFormat) write(ctx context.Context, rows *sql.Rows) error {
	// One scan slot per column; scanArgs holds pointers into values.
	values := make([]any, len(e.columns))
	scanArgs := make([]any, len(e.columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}
	if err := e.Writer.PreWrite(e.columns, e.types); err != nil {
		return err
	}
	for rows.Next() {
		select {
		case <-ctx.Done(): // cancellation
			return ctx.Err()
		default:
		}
		if err := rows.Scan(scanArgs...); err != nil {
			return err
		}
		if err := e.Writer.WriteRow(values, e.columns); err != nil {
			return err
		}
	}
	// Surface any error hit during iteration.
	if err := rows.Err(); err != nil {
		return err
	}
	return e.Writer.PostWrite()
}
// isExecContext returns true if the query is not a SELECT statement.
// Queries that return no rows in SQlite should use ExecContext and therefore return true.
func (db *DB) isExecContext(query string) bool {
	switch db.driver {
	case "sqlite3", "sqlite":
		return !strings.HasPrefix(strings.ToUpper(query), "SELECT")
	default:
		return false
	}
}
package trdsql
import (
"bufio"
"bytes"
"compress/bzip2"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"log"
"os"
"os/user"
"path/filepath"
"regexp"
"strings"
"github.com/klauspost/compress/zstd"
"github.com/pierrec/lz4/v4"
"github.com/ulikunitz/xz"
)
var (
	// ErrInvalidColumn is returned if invalid column.
	ErrInvalidColumn = errors.New("invalid column")
	// ErrNoReader is returned when there is no reader.
	ErrNoReader = errors.New("no reader")
	// ErrUnknownFormat is returned if the format is unknown.
	ErrUnknownFormat = errors.New("unknown format")
	// ErrNoRows is returned when there are no rows.
	ErrNoRows = errors.New("no rows")
	// ErrUnableConvert is returned if it cannot be converted to a table.
	ErrUnableConvert = errors.New("unable to convert")
	// ErrNoMatchFound is returned if no match is found.
	ErrNoMatchFound = errors.New("no match found")
	// ErrNonDefinition is returned when there is no definition.
	ErrNonDefinition = errors.New("no definition")
	// ErrInvalidJSON is returned when the JSON is invalid.
	ErrInvalidJSON = errors.New("invalid JSON")
	// ErrInvalidYAML is returned when the YAML is invalid.
	ErrInvalidYAML = errors.New("invalid YAML")
)
// Importer is the interface import data into the database.
// Importer parses sql query to decide which file to Import.
// Therefore, the reader does not receive it directly.
type Importer interface {
	// Import rewrites the query after importing referenced tables (background context).
	Import(db *DB, query string) (string, error)
	// ImportContext is Import with an explicit context.
	ImportContext(ctx context.Context, db *DB, query string) (string, error)
}
// ReadFormat represents a structure that satisfies the Importer.
// It carries the read options used for every file it imports.
type ReadFormat struct {
	*ReadOpts
}
// NewImporter returns trdsql default Importer.
// The argument is an option of Functional Option Pattern.
//
// usage:
//
//	trdsql.NewImporter(
//		trdsql.InFormat(trdsql.CSV),
//		trdsql.InHeader(true),
//		trdsql.InDelimiter(";"),
//	)
func NewImporter(options ...ReadOpt) *ReadFormat {
	return &ReadFormat{
		ReadOpts: NewReadOpts(options...),
	}
}
// DefaultDBType is default type. Columns whose type cannot be
// inferred are created with this database type.
const DefaultDBType = "text"
// Import is parses the SQL statement and imports one or more tables.
// Import is called from Exec.
// Return the rewritten SQL and error.
// No error is returned if there is no table to import.
func (i *ReadFormat) Import(db *DB, query string) (string, error) {
	return i.ImportContext(context.Background(), db, query)
}
// ImportContext is parses the SQL statement and imports one or more tables.
// ImportContext is called from ExecContext.
// Return the rewritten SQL and error.
// No error is returned if there is no table to import.
func (i *ReadFormat) ImportContext(ctx context.Context, db *DB, query string) (string, error) {
	// Tokenize the query and locate candidate table (file) names.
	parsedQuery := SQLFields(query)
	tables, tableIdx := TableNames(parsedQuery)
	if len(tables) == 0 {
		// without FROM clause. ex. SELECT 1+1;
		debug.Printf("table not found\n")
		return query, nil
	}
	// Import each referenced file; the map value becomes the quoted table name.
	for fileName := range tables {
		tableName, err := ImportFileContext(ctx, db, fileName, i.ReadOpts)
		if err != nil {
			return query, err
		}
		if len(tableName) > 0 {
			tables[fileName] = tableName
		}
	}
	// replace table names in query with their quoted values
	for _, idx := range tableIdx {
		if table, ok := tables[parsedQuery[idx]]; ok {
			parsedQuery[idx] = table
		}
	}
	// reconstruct the query with quoted table names
	query = strings.Join(parsedQuery, "")
	return query, nil
}
// TableNames returns a map of table names
// that may be tables by a simple SQL parser
// from the query string of the argument,
// along with the locations within the parsed
// query where those table names were found.
func TableNames(parsedQuery []string) (map[string]string, []int) {
	tables := make(map[string]string)
	tableIdx := []int{}
	// tableFlag: we are inside a clause that may introduce a table name.
	// frontFlag: the next non-separator token is a table-name position.
	tableFlag := false
	frontFlag := false
	debug.Printf("[%s]", strings.Join(parsedQuery, "]["))
	for i, w := range parsedQuery {
		switch {
		case strings.Contains(" \t\r\n;=", w): // nolint // Because each character is parsed by SQLFields.
			continue
		case strings.EqualFold(w, "FROM"),
			strings.EqualFold(w, "*FROM"),
			strings.EqualFold(w, "JOIN"),
			strings.EqualFold(w, "TABLE"),
			strings.EqualFold(w, "INTO"),
			strings.EqualFold(w, "UPDATE"):
			tableFlag = true
			frontFlag = true
		case isSQLKeyWords(w):
			// A keyword terminates the table-name region (e.g. WHERE).
			tableFlag = false
		case w == ",":
			// Comma-separated table list: the next token is again a table name.
			frontFlag = true
		default:
			if tableFlag && frontFlag {
				// Strip one trailing ')' so "name)" from a subquery still matches.
				if w[len(w)-1] == ')' {
					w = w[:len(w)-1]
				}
				if !isSQLKeyWords(w) {
					tables[w] = w
					tableIdx = append(tableIdx, i)
				}
			}
			frontFlag = false
		}
	}
	return tables, tableIdx
}
// SQLFields returns an array of string fields
// (interpreting quotes) from the argument query.
// Separators are preserved as their own single-character tokens so the
// query can be reconstructed with strings.Join(fields, "").
func SQLFields(query string) []string {
	parsed := make([]string, 0, len(query)/2)
	buf := new(bytes.Buffer)
	// Quote state: separators inside quoted regions are kept verbatim.
	var singleQuoted, doubleQuoted, backQuote bool
	for _, r := range query {
		switch r {
		case ' ', '\t', '\r', '\n', ',', ';', '=', '(', ')':
			if !singleQuoted && !doubleQuoted && !backQuote {
				// Flush the pending token, then emit the separator itself.
				if buf.Len() != 0 {
					parsed = append(parsed, buf.String())
					buf.Reset()
				}
				parsed = append(parsed, string(r))
			} else {
				buf.WriteRune(r)
			}
			continue
		case '\'':
			if !doubleQuoted && !backQuote {
				singleQuoted = !singleQuoted
			}
		case '"':
			if !singleQuoted && !backQuote {
				doubleQuoted = !doubleQuoted
			}
		case '`':
			if !singleQuoted && !doubleQuoted {
				backQuote = !backQuote
			}
		case '*':
			str := buf.String()
			if strings.ToUpper(str) == "SELECT" { // `SELECT*` to `SELECT *`
				parsed = append(parsed, str)
				parsed = append(parsed, string(r))
				buf.Reset()
				continue
			}
		}
		buf.WriteRune(r)
	}
	// Trailing token, if any.
	if buf.Len() > 0 {
		parsed = append(parsed, buf.String())
	}
	return parsed
}
// isSQLKeyWords reports whether str (case-insensitively) is one of the SQL
// keywords that can follow a table reference — such a token can never itself
// be a table name.
func isSQLKeyWords(str string) bool {
	switch strings.ToUpper(str) {
	case "WHERE", "GROUP", "HAVING", "WINDOW", "UNION", "ORDER", "LIMIT", "OFFSET", "FETCH",
		"FOR", "LEFT", "RIGHT", "CROSS", "INNER", "FULL", "LATERAL", "(SELECT":
		return true
	default:
		return false
	}
}
// ImportFile is imports a file.
// Return the quoted table name and error.
// Do not import if file not found (no error).
// Wildcards can be passed as fileName.
func ImportFile(db *DB, fileName string, readOpts *ReadOpts) (string, error) {
	ctx := context.Background()
	return ImportFileContext(ctx, db, fileName, readOpts)
}
// ImportFileContext is imports a file.
// Return the quoted table name and error.
// Do not import if file not found (no error).
// Wildcards can be passed as fileName.
func ImportFileContext(ctx context.Context, db *DB, fileName string, readOpts *ReadOpts) (string, error) {
	opts, fileName := GuessOpts(readOpts, fileName)
	// importCount keeps jq-derived table names unique per import.
	db.importCount++
	file, err := importFileOpen(fileName)
	if err != nil {
		// Missing files are skipped deliberately: the name may be a real table.
		debug.Printf("%s\n", err)
		return "", nil
	}
	defer func() {
		if deferr := file.Close(); deferr != nil {
			log.Printf("file close:%s", deferr)
		}
	}()
	reader, err := NewReader(file, opts)
	if err != nil {
		return "", err
	}
	tableName := fileName
	if opts.InJQuery != "" {
		tableName = fmt.Sprintf("%s::jq%d", fileName, db.importCount)
	}
	tableName = db.QuotedName(tableName)
	if opts.InRowNumber {
		reader = newRowNumberReader(reader)
	}
	// EOF during pre-read is tolerated: a short file still imports.
	columnNames, err := reader.Names()
	if err != nil {
		if !errors.Is(err, io.EOF) {
			return tableName, err
		}
		debug.Printf("EOF reached before argument number of rows")
	}
	columnTypes, err := reader.Types()
	if err != nil {
		if !errors.Is(err, io.EOF) {
			return tableName, err
		}
		debug.Printf("EOF reached before argument number of rows")
	}
	debug.Printf("Column Names: [%v]", strings.Join(columnNames, ","))
	debug.Printf("Column Types: [%v]", strings.Join(columnTypes, ","))
	if err := db.CreateTableContext(ctx, tableName, columnNames, columnTypes, opts.IsTemporary); err != nil {
		return tableName, err
	}
	return tableName, db.ImportContext(ctx, tableName, columnNames, reader)
}
// GuessOpts guesses ReadOpts from the file name and sets it.
// NOTE: readOpts is mutated in place and returned.
// A "::" suffix is treated as a jq expression only when the full name
// does not exist as a file on disk.
func GuessOpts(readOpts *ReadOpts, fileName string) (*ReadOpts, string) {
	if _, err := os.Stat(fileName); err != nil {
		if idx := strings.Index(fileName, "::"); idx != -1 {
			// jq expression.
			readOpts.InJQuery = fileName[idx+2:]
			fileName = fileName[:idx]
		}
	}
	// An explicit format wins over extension-based guessing.
	if readOpts.InFormat != GUESS {
		readOpts.realFormat = readOpts.InFormat
		return readOpts, fileName
	}
	format := guessFormat(fileName)
	readOpts.realFormat = format
	debug.Printf("Guess file type as %s: [%s]", readOpts.realFormat, fileName)
	return readOpts, fileName
}
// guessFormat is guess format from the file name extension.
// Format extensions are searched recursively to remove
// compression extensions such as .gz.
func guessFormat(fileName string) Format {
	name := strings.TrimRight(fileName, "\"'`")
	for {
		dotExt := filepath.Ext(name)
		if dotExt == "" {
			debug.Printf("Set in CSV because the extension is unknown: [%s]", name)
			return CSV
		}
		// dotExt always starts with "."; look up the upper-cased extension.
		if format, ok := extToFormat[strings.ToUpper(dotExt[1:])]; ok {
			return format
		}
		// Drop the unrecognized (compression) extension and retry.
		name = strings.TrimSuffix(name, dotExt)
	}
}
// wildcardRegexp matches glob metacharacters; compiled once at package
// scope instead of on every importFileOpen call.
var wildcardRegexp = regexp.MustCompile(`\*|\?|\[`)

// importFileOpen opens the file specified as a table.
// Names containing glob metacharacters are expanded via globFileOpen.
func importFileOpen(tableName string) (io.ReadCloser, error) {
	if wildcardRegexp.MatchString(tableName) {
		return globFileOpen(tableName)
	}
	return singleFileOpen(tableName)
}
// uncompressedReader returns the decompressed reader
// if it is a compressed file.
// It sniffs up to 7 magic bytes, then stitches them back in front of the
// remaining stream so no data is lost.
func uncompressedReader(reader io.Reader) io.ReadCloser {
	var err error
	buf := [7]byte{}
	n, err := io.ReadAtLeast(reader, buf[:], len(buf))
	if err != nil {
		// Short input: return just the bytes we managed to read.
		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
			return io.NopCloser(bytes.NewReader(buf[:n]))
		}
		return io.NopCloser(bytes.NewReader(nil))
	}
	// Re-prepend the sniffed bytes to the stream.
	rd := io.MultiReader(bytes.NewReader(buf[:n]), reader)
	var r io.ReadCloser
	switch {
	case bytes.Equal(buf[:3], []byte{0x1f, 0x8b, 0x8}): // gzip magic
		r, err = gzip.NewReader(rd)
	case bytes.Equal(buf[:3], []byte{0x42, 0x5A, 0x68}): // bzip2 magic "BZh"
		r = io.NopCloser(bzip2.NewReader(rd))
	case bytes.Equal(buf[:4], []byte{0x28, 0xb5, 0x2f, 0xfd}): // zstd magic
		var zr *zstd.Decoder
		zr, err = zstd.NewReader(rd)
		r = io.NopCloser(zr)
	case bytes.Equal(buf[:4], []byte{0x04, 0x22, 0x4d, 0x18}): // lz4 magic
		r = io.NopCloser(lz4.NewReader(rd))
	case bytes.Equal(buf[:7], []byte{0xfd, 0x37, 0x7a, 0x58, 0x5a, 0x0, 0x0}): // xz magic
		var zr *xz.Reader
		zr, err = xz.NewReader(rd)
		r = io.NopCloser(zr)
	}
	// Unknown or failed decompressor: fall back to the raw stream.
	if err != nil || r == nil {
		r = io.NopCloser(rd)
	}
	return r
}
// singleFileOpen opens one file. Also interpret stdin.
// "", "-" and "stdin" (case-insensitive) all mean standard input.
func singleFileOpen(fileName string) (io.ReadCloser, error) {
	switch {
	case len(fileName) == 0, fileName == "-", strings.ToLower(fileName) == "stdin":
		return uncompressedReader(bufio.NewReader(os.Stdin)), nil
	}
	path := expandTilde(trimQuote(fileName))
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	return uncompressedReader(file), nil
}
// globFileOpen expands the file path,
// connects multiple files and returns one io.PipeReader.
// The writing side runs in a goroutine; it exits (closing the pipe) after
// all matched files have been copied.
func globFileOpen(globName string) (*io.PipeReader, error) {
	globName = expandTilde(trimQuote(globName))
	fileNames, err := filepath.Glob(globName)
	if err != nil {
		return nil, err
	}
	if len(fileNames) == 0 {
		return nil, fmt.Errorf("%w: %s", ErrNoMatchFound, fileNames)
	}
	pipeReader, pipeWriter := io.Pipe()
	go func() {
		defer func() {
			if err := pipeWriter.Close(); err != nil {
				log.Printf("pipe close:%s", err)
			}
		}()
		// Copy each matched file in turn; a failing file is logged and skipped.
		for _, fileName := range fileNames {
			if err := copyFileOpen(pipeWriter, fileName); err != nil {
				log.Printf("ERROR: %s:%s", fileName, err)
				continue
			}
		}
	}()
	return pipeReader, nil
}
// copyFileOpen opens the file and copies it to the writer.
// A trailing newline is appended in case the file does not end with one.
// The file is always closed, including on copy errors (the original leaked
// the handle on early returns); a close failure is reported only when no
// earlier error occurred.
func copyFileOpen(writer io.Writer, fileName string) (err error) {
	debug.Printf("Open: [%s]", fileName)
	file, err := os.Open(fileName)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := file.Close(); cerr != nil {
			if err == nil {
				err = cerr
			}
			return
		}
		debug.Printf("Close: [%s]", fileName)
	}()
	r := uncompressedReader(file)
	if _, err := io.Copy(writer, r); err != nil {
		return err
	}
	// For if the file does not have a line break before EOF.
	if _, err := writer.Write([]byte("\n")); err != nil {
		return err
	}
	return nil
}
// expandTilde replaces a leading "~" with the current user's home
// directory. Any other name is returned unchanged, as is the input when
// the current user cannot be determined.
func expandTilde(fileName string) string {
	if !strings.HasPrefix(fileName, "~") {
		return fileName
	}
	usr, err := user.Current()
	if err != nil {
		log.Printf("ERROR: %s", err)
		return fileName
	}
	return filepath.Join(usr.HomeDir, fileName[1:])
}
// trimQuote removes one level of surrounding back-quotes or double-quotes.
// Strings shorter than two characters are returned unchanged — this guards
// against the out-of-range panic the original hit on "" and on a lone quote
// character such as "`" (where str[1:len(str)-1] is an invalid slice).
func trimQuote(str string) string {
	if len(str) < 2 {
		return str
	}
	if str[0] == '`' && str[len(str)-1] == '`' {
		str = str[1 : len(str)-1]
	}
	// Re-check the length: trimming "``" leaves an empty string.
	if len(str) >= 2 && str[0] == '"' && str[len(str)-1] == '"' {
		str = str[1 : len(str)-1]
	}
	return str
}
// trimQuoteAll removes one level of matching single-quote, back-quote,
// or double-quote wrapping; anything else is returned as-is.
func trimQuoteAll(str string) string {
	if len(str) < 2 {
		return str
	}
	first, last := str[0], str[len(str)-1]
	if first != last {
		return str
	}
	switch first {
	case '\'', '`', '"':
		return str[1 : len(str)-1]
	}
	return str
}
package trdsql
import (
"context"
"io"
)
// BufferImporter a structure that includes tableName and Reader.
// It imports a single in-memory reader under a fixed table name.
type BufferImporter struct {
	Reader
	tableName string // table the buffered data is imported into
}
// NewBufferImporter returns trdsql BufferImporter.
// The input format is taken from the options as-is (no guessing from a name).
func NewBufferImporter(tableName string, r io.Reader, options ...ReadOpt) (*BufferImporter, error) {
	readOpts := NewReadOpts(options...)
	readOpts.realFormat = readOpts.InFormat
	reader, err := NewReader(r, readOpts)
	if err != nil {
		return nil, err
	}
	importer := &BufferImporter{tableName: tableName}
	importer.Reader = reader
	return importer, nil
}
// Import is a method to import from Reader in BufferImporter.
func (i *BufferImporter) Import(db *DB, query string) (string, error) {
	return i.ImportContext(context.Background(), db, query)
}
// ImportContext is a method to import from Reader in BufferImporter.
// The table is created as a temporary table. Table creation now uses
// CreateTableContext so the caller's ctx is honored (previously CreateTable
// ignored it, inconsistent with ImportFileContext).
func (i *BufferImporter) ImportContext(ctx context.Context, db *DB, query string) (string, error) {
	names, err := i.Names()
	if err != nil {
		return query, err
	}
	types, err := i.Types()
	if err != nil {
		return query, err
	}
	if err := db.CreateTableContext(ctx, i.tableName, names, types, true); err != nil {
		return query, err
	}
	return query, db.ImportContext(ctx, i.tableName, names, i.Reader)
}
package trdsql
import "context"
// SliceImporter is a structure that includes SliceReader.
// SliceImporter can be used as a library from another program.
// It is not used from the command.
// SliceImporter is an importer that reads one slice data.
type SliceImporter struct {
	*SliceReader
}
// NewSliceImporter returns trdsql SliceImporter.
func NewSliceImporter(tableName string, data any) *SliceImporter {
	reader := NewSliceReader(tableName, data)
	return &SliceImporter{SliceReader: reader}
}
// Import is a method to import from SliceReader in SliceImporter.
func (i *SliceImporter) Import(db *DB, query string) (string, error) {
	return i.ImportContext(context.Background(), db, query)
}
// ImportContext is a method to import from SliceReader in SliceImporter.
// The table is created as a temporary table. Table creation now uses
// CreateTableContext so the caller's ctx is honored (previously CreateTable
// ignored it, inconsistent with ImportFileContext).
func (i *SliceImporter) ImportContext(ctx context.Context, db *DB, query string) (string, error) {
	names, err := i.Names()
	if err != nil {
		return query, err
	}
	types, err := i.Types()
	if err != nil {
		return query, err
	}
	if err := db.CreateTableContext(ctx, i.tableName, names, types, true); err != nil {
		return query, err
	}
	return query, db.ImportContext(ctx, i.tableName, names, i.SliceReader)
}
package trdsql
import (
"encoding/csv"
"errors"
"fmt"
"io"
"strconv"
)
// CSVReader provides methods of the Reader interface.
type CSVReader struct {
	reader    *csv.Reader
	inNULL    string     // string value to be converted to NULL
	names     []string   // column names (from header or generated cN)
	types     []string   // column types (always DefaultDBType)
	preRead   [][]string // rows read ahead during construction
	limitRead bool       // when true, only preRead rows are served
	needNULL  bool       // convert inNULL values to NULL
}
// NewCSVReader returns CSVReader and error.
// It configures the delimiter, optionally consumes a header row for column
// names, and pre-reads opts.InPreRead rows to establish the column set.
func NewCSVReader(reader io.Reader, opts *ReadOpts) (*CSVReader, error) {
	r := &CSVReader{}
	r.reader = csv.NewReader(reader)
	r.reader.LazyQuotes = true
	r.reader.FieldsPerRecord = -1 // no check count
	d, err := delimiter(opts.InDelimiter)
	if err != nil {
		return nil, err
	}
	r.reader.Comma = d
	// Space-delimited input: collapse leading spaces between fields.
	if r.reader.Comma == ' ' {
		r.reader.TrimLeadingSpace = true
	}
	if opts.InSkip > 0 {
		skipRead(r, opts.InSkip)
	}
	r.needNULL = opts.InNeedNULL
	r.inNULL = opts.InNULL
	r.limitRead = opts.InLimitRead
	// Read the header.
	preReadN := opts.InPreRead
	if opts.InHeader {
		row, err := r.reader.Read()
		if err != nil {
			if !errors.Is(err, io.EOF) {
				return nil, err
			}
			// EOF here leaves row empty; names stay empty as well.
		}
		r.names = make([]string, len(row))
		for i, col := range row {
			r.names[i] = col
			// Empty header cells get generated names c1, c2, ...
			if col == "" {
				r.names[i] = "c" + strconv.Itoa(i+1)
			}
		}
		// The header consumed one of the pre-read rows.
		preReadN--
	}
	// Pre-read and stored in slices.
	for n := 0; n < preReadN; n++ {
		row, err := r.reader.Read()
		if err != nil {
			if !errors.Is(err, io.EOF) {
				return r, err
			}
			// EOF during pre-read is not an error; finish with what we have.
			r.setColumnType()
			debug.Printf(err.Error())
			return r, nil
		}
		rows := make([]string, len(row))
		for i, col := range row {
			rows[i] = col
			// If there are more columns than header, add column names.
			if len(r.names) < i+1 {
				r.names = append(r.names, "c"+strconv.Itoa(i+1))
			}
		}
		r.preRead = append(r.preRead, rows)
	}
	r.setColumnType()
	return r, nil
}
// NewTSVReader returns a CSVReader configured for tab-separated values.
func NewTSVReader(reader io.Reader, opts *ReadOpts) (*CSVReader, error) {
	opts.InDelimiter = "\t"
	return NewCSVReader(reader, opts)
}
// NewPSVReader returns a CSVReader configured for pipe-separated values.
func NewPSVReader(reader io.Reader, opts *ReadOpts) (*CSVReader, error) {
	opts.InDelimiter = "|"
	return NewCSVReader(reader, opts)
}
// setColumnType assigns DefaultDBType to every known column.
// It is a no-op while no column names have been discovered.
func (r *CSVReader) setColumnType() {
	if r.names == nil {
		return
	}
	types := make([]string, len(r.names))
	for i := range types {
		types[i] = DefaultDBType
	}
	r.types = types
}
// delimiter converts the delimiter string (which may contain escape
// sequences such as \t) into a single rune. An empty string yields 0.
func delimiter(sepString string) (rune, error) {
	if sepString == "" {
		return 0, nil
	}
	unquoted, err := strconv.Unquote(`'` + sepString + `'`)
	if err != nil {
		return ',', fmt.Errorf("can not get separator: %w:\"%s\"", err, sepString)
	}
	return []rune(unquoted)[0], nil
}
// Names returns column names.
// ErrNoRows is returned when no columns were discovered.
func (r *CSVReader) Names() ([]string, error) {
	if len(r.names) == 0 {
		return r.names, ErrNoRows
	}
	return r.names, nil
}
// Types returns column types.
// All CSV types return the DefaultDBType.
// ErrNoRows is returned when no columns were discovered.
func (r *CSVReader) Types() ([]string, error) {
	if len(r.types) == 0 {
		return r.types, ErrNoRows
	}
	return r.types, nil
}
// PreReadRow is returns only columns that store preread rows.
// Values equal to inNULL are replaced with NULL when needNULL is set.
func (r *CSVReader) PreReadRow() [][]any {
	rows := make([][]any, len(r.preRead))
	for n, record := range r.preRead {
		row := make([]any, len(r.names))
		for i, field := range record {
			if r.needNULL {
				row[i] = replaceNULL(r.inNULL, field)
			} else {
				row[i] = field
			}
		}
		rows[n] = row
	}
	return rows
}
// ReadRow is read the rest of the row.
// Missing trailing fields are filled with nil; io.EOF is returned
// immediately when reading is limited to the pre-read rows.
func (r *CSVReader) ReadRow(row []any) ([]any, error) {
	if r.limitRead {
		return nil, io.EOF
	}
	record, err := r.reader.Read()
	if err != nil {
		return row, err
	}
	for i := range row {
		if i >= len(record) {
			row[i] = nil
			continue
		}
		row[i] = record[i]
		if r.needNULL {
			row[i] = replaceNULL(r.inNULL, row[i])
		}
	}
	return row, nil
}
package trdsql
import (
"errors"
"io"
"github.com/noborus/guesswidth"
)
// GWReader provides methods of the Reader interface.
type GWReader struct {
	reader    *guesswidth.GuessWidth
	scanNum   int      // number of lines to scan for width guessing
	preRead   int      // number of rows to serve from PreReadRow
	inNULL    string   // string value to be converted to NULL
	names     []string // column names from the header line
	types     []string // column types (always DefaultDBType)
	limitRead bool     // when true, only preRead rows are served
	needNULL  bool     // convert inNULL values to NULL
}
// NewGWReader returns GWReader and error.
// Column boundaries are guessed from up to scanNum lines; the first
// non-skipped line supplies the column names.
func NewGWReader(reader io.Reader, opts *ReadOpts) (*GWReader, error) {
	r := &GWReader{}
	r.reader = guesswidth.NewReader(reader)
	r.reader.TrimSpace = true
	r.limitRead = opts.InLimitRead
	r.reader.Header = opts.InSkip
	r.scanNum = 1000
	r.needNULL = opts.InNeedNULL
	r.inNULL = opts.InNULL
	r.preRead = opts.InPreRead
	// Scan at least as many lines as we intend to pre-read.
	if r.preRead > r.scanNum {
		r.scanNum = r.preRead
	}
	r.reader.Scan(r.scanNum)
	// Skip leading rows. NOTE(review): non-EOF read errors here are
	// silently ignored and skipping continues — looks like best-effort
	// behavior; confirm this is intentional.
	for i := 0; i < opts.InSkip; i++ {
		if _, err := r.reader.Read(); err != nil {
			if errors.Is(err, io.EOF) {
				return r, nil
			}
		}
	}
	// The next row is used as the header (column names).
	names, err := r.reader.Read()
	if err != nil {
		if errors.Is(err, io.EOF) {
			return r, nil
		}
		return nil, err
	}
	r.names = names
	r.setColumnType()
	return r, nil
}
// setColumnType assigns DefaultDBType to every known column.
// It is a no-op while no column names have been discovered.
func (r *GWReader) setColumnType() {
	if r.names == nil {
		return
	}
	types := make([]string, len(r.names))
	for i := range types {
		types[i] = DefaultDBType
	}
	r.types = types
}
// Names returns column names. It never fails.
func (r *GWReader) Names() ([]string, error) {
	return r.names, nil
}
// Types returns column types.
// All GW types return the DefaultDBType.
func (r *GWReader) Types() ([]string, error) {
	return r.types, nil
}
// PreReadRow is returns only columns that store preread rows.
// NOTE(review): if the input ends before preRead rows are read, the
// returned slice keeps its full length with nil entries for the unread
// tail — confirm callers tolerate nil rows.
func (r *GWReader) PreReadRow() [][]any {
	rows := make([][]any, r.preRead)
	for n := 0; n < r.preRead; n++ {
		record, err := r.reader.Read()
		if err != nil {
			return rows
		}
		rows[n] = make([]any, len(r.names))
		for i := 0; i < len(r.names); i++ {
			rows[n][i] = record[i]
			if r.needNULL {
				rows[n][i] = replaceNULL(r.inNULL, rows[n][i])
			}
		}
	}
	return rows
}
// ReadRow is read the rest of the row.
// Short records are padded with nil instead of panicking on record[i]
// (consistent with CSVReader.ReadRow, which guards the record length).
func (r *GWReader) ReadRow(row []any) ([]any, error) {
	if r.limitRead {
		return nil, io.EOF
	}
	record, err := r.reader.Read()
	if err != nil {
		return row, err
	}
	for i := 0; i < len(row); i++ {
		if i >= len(record) {
			// Fewer fields than columns: fill the remainder with NULL.
			row[i] = nil
			continue
		}
		row[i] = record[i]
		if r.needNULL {
			row[i] = replaceNULL(r.inNULL, row[i])
		}
	}
	return row, nil
}
package trdsql
// Convert JSON to a table.
// Supports the following JSON container types.
// * Array ([{c1: 1}, {c1: 2}, {c1: 3}])
// * Multiple JSON ({c1: 1}\n {c1: 2}\n {c1: 3}\n)
// Make a table from json
// or make the result of json filter by jq.
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"github.com/itchyny/gojq"
)
// JSONReader provides methods of the Reader interface.
type JSONReader struct {
	reader    *json.Decoder
	query     *gojq.Query      // optional jq filter applied to each document
	already   map[string]bool  // tracks column names already registered
	inNULL    string           // string value to be converted to NULL
	preRead   []map[string]any // rows read ahead during construction
	names     []string         // column names in discovery order
	types     []string         // column types (always DefaultDBType)
	limitRead bool             // when true, only preRead rows are served
	needNULL  bool             // convert inNULL values to NULL
}
// NewJSONReader returns JSONReader and error.
// Up to opts.InPreRead top-level JSON documents are decoded ahead of time
// to discover the column set. When a jq query is set, only the first
// document is run through it and pre-reading stops there.
func NewJSONReader(reader io.Reader, opts *ReadOpts) (*JSONReader, error) {
	r := &JSONReader{}
	r.reader = json.NewDecoder(reader)
	// UseNumber avoids float64 precision loss on large integers.
	r.reader.UseNumber()
	r.already = make(map[string]bool)
	var top any
	if opts.InJQuery != "" {
		str := trimQuoteAll(opts.InJQuery)
		query, err := gojq.Parse(str)
		if err != nil {
			return nil, fmt.Errorf("%w gojq:(%s)", err, opts.InJQuery)
		}
		r.query = query
	}
	r.limitRead = opts.InLimitRead
	r.needNULL = opts.InNeedNULL
	r.inNULL = opts.InNULL
	for i := 0; i < opts.InPreRead; i++ {
		if err := r.reader.Decode(&top); err != nil {
			if !errors.Is(err, io.EOF) {
				return r, fmt.Errorf("%w: %s", ErrInvalidJSON, err)
			}
			// EOF simply ends pre-reading.
			debug.Printf(err.Error())
			return r, nil
		}
		if r.query != nil {
			// jq mode: filter the first document and stop pre-reading.
			if err := r.jqueryRun(top); err != nil {
				return nil, err
			}
			return r, nil
		}
		if err := r.readAhead(top); err != nil {
			return nil, err
		}
	}
	return r, nil
}
// Names returns column names discovered during pre-reading. It never fails.
func (r *JSONReader) Names() ([]string, error) {
	return r.names, nil
}
// Types returns column types.
// All JSON types return the DefaultDBType.
func (r *JSONReader) Types() ([]string, error) {
	types := make([]string, len(r.names))
	for i := range types {
		types[i] = DefaultDBType
	}
	r.types = types
	return r.types, nil
}
// readAhead parses the top level of the JSON and stores it in preRead.
// An array with no further documents behind it is treated as the row set;
// an array inside a JSONL stream is treated as a single (array-valued) row.
func (r *JSONReader) readAhead(top any) error {
	switch m := top.(type) {
	case []any:
		// []
		r.preRead = make([]map[string]any, 0, len(m))
		if r.reader.More() {
			// More documents follow (JSONL): the whole array is one row.
			pre, names, err := r.etcRow(m)
			if err != nil {
				return err
			}
			r.appendNames(names)
			r.preRead = append(r.preRead, pre)
			return nil
		}
		// Sole document: each array element is a row.
		for _, v := range m {
			pre, names, err := r.topLevel(v)
			if err != nil {
				return err
			}
			r.appendNames(names)
			r.preRead = append(r.preRead, pre)
		}
		return nil
	default:
		pre, names, err := r.topLevel(m)
		if err != nil {
			return err
		}
		r.appendNames(names)
		r.preRead = append(r.preRead, pre)
	}
	return nil
}
// appendNames adds multiple names for the argument to be unique.
func (r *JSONReader) appendNames(names []string) {
	for _, name := range names {
		if r.already[name] {
			continue
		}
		r.already[name] = true
		r.names = append(r.names, name)
	}
}
// topLevel converts one top-level JSON value into a row: objects map
// their keys to columns, anything else becomes a single-column row.
func (r *JSONReader) topLevel(top any) (map[string]any, []string, error) {
	if obj, ok := top.(map[string]any); ok {
		return r.objectRow(obj)
	}
	return r.etcRow(top)
}
// PreReadRow is returns only columns that store preRead rows.
// One json (not jsonl) returns all rows with preRead.
func (r *JSONReader) PreReadRow() [][]any {
	rows := make([][]any, len(r.preRead))
	for n, record := range r.preRead {
		row := make([]any, len(r.names))
		for i, name := range r.names {
			value := record[name]
			if r.needNULL {
				value = replaceNULL(r.inNULL, value)
			}
			row[i] = value
		}
		rows[n] = row
	}
	return rows
}
// ReadRow is read the rest of the row.
// Only jsonl requires ReadRow in json.
func (r *JSONReader) ReadRow(row []any) ([]any, error) {
	if r.limitRead {
		return nil, io.EOF
	}
	var data any
	if err := r.reader.Decode(&data); err != nil {
		return nil, err
	}
	// With a jq filter, the decoded document is piped through it first.
	if r.query == nil {
		return r.rowParse(row, data), nil
	}
	return r.jqueryRunJsonl(row, data)
}
// rowParse fills row from one decoded JSON value: object keys map to
// columns; a non-object value goes into the first column, the rest are nil.
func (r *JSONReader) rowParse(row []any, jsonRow any) []any {
	if m, ok := jsonRow.(map[string]any); ok {
		for i, name := range r.names {
			row[i] = r.jsonString(m[name])
		}
		return row
	}
	for i := range r.names {
		row[i] = nil
	}
	row[0] = r.jsonString(jsonRow)
	return row
}
// objectRow converts a JSON object into a row map plus its key names.
func (r *JSONReader) objectRow(obj map[string]any) (map[string]any, []string, error) {
	// {"a":"b"} object
	names := make([]string, 0, len(obj))
	row := make(map[string]any, len(obj))
	for key, value := range obj {
		names = append(names, key)
		row[key] = r.jsonString(value)
	}
	return row, names, nil
}
// etcRow wraps a non-object JSON value into a single-column ("c1") row.
// ex. array of arrays: [["a"], ["b"]]
func (r *JSONReader) etcRow(val any) (map[string]any, []string, error) {
	const key = "c1"
	row := map[string]any{key: r.jsonString(val)}
	return row, []string{key}, nil
}
// jqueryRun is a gojq.Run for json.
// Every value produced by the jq filter is fed into readAhead; a filter
// error aborts the whole run.
func (r *JSONReader) jqueryRun(top any) error {
	iter := r.query.Run(top)
	for {
		v, ok := iter.Next()
		if !ok {
			break
		}
		if err, ok := v.(error); ok {
			return fmt.Errorf("%w gojq:(%s) ", err, r.query)
		}
		if err := r.readAhead(v); err != nil {
			return err
		}
	}
	return nil
}
// jqueryRunJsonl gojq.Run for rows of jsonl.
// Filter errors on individual values are logged and skipped rather than
// aborting the stream; the last produced value wins when the filter
// yields more than one.
func (r *JSONReader) jqueryRunJsonl(row []any, jsonRow any) ([]any, error) {
	iter := r.query.Run(jsonRow)
	for {
		v, ok := iter.Next()
		if !ok {
			break
		}
		if err, ok := v.(error); ok {
			debug.Printf("%s gojq: %s", err.Error(), r.query)
			continue
		}
		row = r.rowParse(row, v)
	}
	return row, nil
}
// jsonString returns the string of the argument.
// Nested containers are re-marshaled to their JSON text; nil stays nil.
// NOTE(review): a marshal failure is only logged and yields an empty
// string — confirm this best-effort behavior is intended.
func (r *JSONReader) jsonString(val any) any {
	var str string
	switch val.(type) {
	case nil:
		return nil
	case map[string]any, []any:
		b, err := json.Marshal(val)
		if err != nil {
			log.Printf("ERROR: jsonString:%s", err)
		}
		str = ValString(b)
	default:
		str = ValString(val)
	}
	if r.needNULL {
		return replaceNULL(r.inNULL, str)
	}
	return str
}
package trdsql
import (
"bufio"
"errors"
"io"
"strings"
)
// LTSVReader provides methods of the Reader interface.
type LTSVReader struct {
	reader    *bufio.Reader
	delimiter string              // field delimiter (tab)
	inNULL    string              // string value to be converted to NULL
	preRead   []map[string]string // rows read ahead during construction
	names     []string            // union of labels in discovery order
	types     []string            // column types (always DefaultDBType)
	limitRead bool                // when true, only preRead rows are served
	needNULL  bool                // convert inNULL values to NULL
}
// NewLTSVReader returns LTSVReader and error.
func NewLTSVReader(reader io.Reader, opts *ReadOpts) (*LTSVReader, error) {
r := <SVReader{}
r.reader = bufio.NewReader(reader)
r.delimiter = "\t"
if opts.InSkip > 0 {
skipRead(r, opts.InSkip)
}
r.limitRead = opts.InLimitRead
r.needNULL = opts.InNeedNULL
r.inNULL = opts.InNULL
names := map[string]bool{}
for i := 0; i < opts.InPreRead; i++ {
row, keys, err := r.read()
if err != nil {
if !errors.Is(err, io.EOF) {
return r, err
}
r.setColumnType()
debug.Printf(err.Error())
return r, nil
}
// Add only unique column names.
for k := 0; k < len(keys); k++ {
if !names[keys[k]] {
names[keys[k]] = true
r.names = append(r.names, keys[k])
}
}
r.preRead = append(r.preRead, row)
}
r.setColumnType()
return r, nil
}
// setColumnType assigns DefaultDBType to every known column.
// It is a no-op while no column names have been discovered.
func (r *LTSVReader) setColumnType() {
	if r.names == nil {
		return
	}
	types := make([]string, len(r.names))
	for i := range types {
		types[i] = DefaultDBType
	}
	r.types = types
}
// Names returns column names discovered during pre-reading. It never fails.
func (r *LTSVReader) Names() ([]string, error) {
	return r.names, nil
}
// Types returns column types.
// All LTSV types return the DefaultDBType.
func (r *LTSVReader) Types() ([]string, error) {
	return r.types, nil
}
// PreReadRow is returns only columns that store preread rows.
// Labels absent from a given row yield the empty string.
func (r *LTSVReader) PreReadRow() [][]any {
	rows := make([][]any, len(r.preRead))
	for n, record := range r.preRead {
		row := make([]any, len(r.names))
		for i, name := range r.names {
			var value any = record[name]
			if r.needNULL {
				value = replaceNULL(r.inNULL, value)
			}
			row[i] = value
		}
		rows[n] = row
	}
	return rows
}
// ReadRow is read the rest of the row.
// io.EOF is returned immediately when reading is limited to pre-read rows.
func (r *LTSVReader) ReadRow(row []any) ([]any, error) {
	if r.limitRead {
		return nil, io.EOF
	}
	record, _, err := r.read()
	if err != nil {
		return row, err
	}
	for i, name := range r.names {
		var value any = record[name]
		if r.needNULL {
			value = replaceNULL(r.inNULL, value)
		}
		row[i] = value
	}
	return row, nil
}
// read parses one LTSV line into a label→value map plus the labels in
// their original order. A field without a ":" separator is an error.
func (r *LTSVReader) read() (map[string]string, []string, error) {
	line, err := r.readline()
	if err != nil {
		return nil, nil, err
	}
	fields := strings.Split(line, r.delimiter)
	lvs := make(map[string]string, len(fields))
	keys := make([]string, 0, len(fields))
	for _, field := range fields {
		label, value, ok := strings.Cut(field, ":")
		if !ok {
			return nil, nil, ErrInvalidColumn
		}
		lvs[label] = value
		keys = append(keys, label)
	}
	return lvs, keys, nil
}
// readline returns the next non-empty line, reassembling lines that the
// bufio.Reader split because they exceeded its buffer (isPrefix).
// Blank lines are skipped; trailing/leading whitespace is trimmed.
func (r *LTSVReader) readline() (string, error) {
	var builder strings.Builder
	for {
		line, isPrefix, err := r.reader.ReadLine()
		if err != nil {
			return "", err
		}
		builder.Write(line)
		// Partial line: keep accumulating until the full line is read.
		if isPrefix {
			continue
		}
		str := strings.TrimSpace(builder.String())
		if len(str) != 0 {
			return str, nil
		}
		// Empty line: reset and try the next one.
		builder.Reset()
	}
}
package trdsql
import (
"strconv"
)
// rowNumberReader is a Reader that adds a row number column to the input.
type rowNumberReader struct {
	reader    Reader
	originRow []any // scratch row reused for the wrapped reader's ReadRow
	lineCount int   // number of rows emitted so far (1-based numbering)
}
// newRowNumberReader creates a new rowNumberReader.
// The scratch row is sized to the wrapped reader's column count,
// defaulting to one column if names are unavailable.
func newRowNumberReader(r Reader) *rowNumberReader {
	columnNum := 1
	if names, err := r.Names(); err == nil {
		columnNum = len(names)
	}
	return &rowNumberReader{
		reader:    r,
		originRow: make([]any, columnNum),
		lineCount: 0,
	}
}
// Names returns column names with an additional row number column
// prepended. The column is named "num"; if that collides with an
// existing column, numeric suffixes ("num0", "num1", ...) are tried
// until a unique name is found.
func (r *rowNumberReader) Names() ([]string, error) {
	names, err := r.reader.Names()
	if err != nil {
		return nil, err
	}
	number := "num"
	// Retry with a new suffix until the candidate collides with no
	// existing name. The original loop advanced its index from inside
	// the inner range and never re-checked earlier names, so a renamed
	// candidate could still collide.
	for i := 0; ; i++ {
		unique := true
		for _, name := range names {
			if name == number {
				unique = false
				break
			}
		}
		if unique {
			break
		}
		number = "num" + strconv.Itoa(i)
	}
	return append([]string{number}, names...), nil
}
// Types returns column types with an "int" type prepended for the
// row number column.
func (r *rowNumberReader) Types() ([]string, error) {
	types, err := r.reader.Types()
	if err != nil {
		return nil, err
	}
	return append([]string{"int"}, types...), nil
}

// PreReadRow returns the pre-read rows, each prefixed with its row
// number, and advances the running count.
func (r *rowNumberReader) PreReadRow() [][]any {
	rows := r.reader.PreReadRow()
	for i, row := range rows {
		rows[i] = append([]any{r.lineCount + i + 1}, row...)
	}
	r.lineCount += len(rows)
	return rows
}

// ReadRow reads the next row from the wrapped reader and prefixes the
// current row number.
func (r *rowNumberReader) ReadRow(row []any) ([]any, error) {
	r.lineCount++
	origin, err := r.reader.ReadRow(r.originRow)
	if err != nil {
		return nil, err
	}
	r.originRow = origin
	if len(origin) == 0 {
		return nil, nil
	}
	return append([]any{r.lineCount}, origin...), nil
}
package trdsql
import (
"fmt"
"io"
"reflect"
)
// SliceReader is a structure for reading tabular data in memory.
// It can be used as the trdsql reader interface.
type SliceReader struct {
	tableName string   // table name exposed to the SQL layer
	names     []string // column names
	types     []string // column DB types (see typeToDBType)
	data      [][]any  // row data, one inner slice per row
}
// NewSliceReader takes a tableName and tabular data in memory
// and returns SliceReader.
// Accepted shapes: a scalar value, a one-dimensional slice, a
// two-dimensional slice, a map, a struct, or a slice of structs.
// Pointers are dereferenced first.
func NewSliceReader(tableName string, args any) *SliceReader {
	val := reflect.ValueOf(args)
	if val.Kind() == reflect.Ptr {
		val = reflect.Indirect(val)
	}
	switch val.Kind() {
	case reflect.Map:
		return mapReader(tableName, val)
	case reflect.Struct:
		return structReader(tableName, val)
	case reflect.Slice:
		return sliceReader(tableName, val)
	}
	// Scalar value: one row with a single column named c1.
	return &SliceReader{
		tableName: tableName,
		names:     []string{"c1"},
		types:     []string{typeToDBType(val.Kind())},
		data:      [][]any{{val.Interface()}},
	}
}
// mapReader converts a map into a two-column (key, value) SliceReader:
// column c1 holds the keys and c2 the values. Types are inferred from
// the first key/value pair. An empty map yields a reader with no rows.
func mapReader(tableName string, val reflect.Value) *SliceReader {
	val = reflect.Indirect(val)
	names := []string{"c1", "c2"}
	keys := val.MapKeys()
	// Guard: keys[0] below would panic on an empty map.
	if len(keys) == 0 {
		return &SliceReader{
			tableName: tableName,
			names:     names,
			types:     []string{DefaultDBType, DefaultDBType},
			data:      nil,
		}
	}
	keyType := keys[0].Kind()
	valType := val.MapIndex(keys[0]).Kind()
	types := []string{typeToDBType(keyType), typeToDBType(valType)}
	data := make([][]any, 0, len(keys))
	for _, k := range keys {
		data = append(data, []any{k.Interface(), val.MapIndex(k).Interface()})
	}
	return &SliceReader{
		tableName: tableName,
		names:     names,
		types:     types,
		data:      data,
	}
}
// structReader converts a single struct value into a one-row
// SliceReader: field names become column names and every field value is
// rendered with %v.
func structReader(tableName string, val reflect.Value) *SliceReader {
	t := val.Type()
	n := t.NumField()
	names := make([]string, n)
	types := make([]string, n)
	row := make([]any, n)
	for i := 0; i < n; i++ {
		field := t.Field(i)
		names[i] = field.Name
		types[i] = typeToDBType(field.Type.Kind())
		row[i] = fmt.Sprintf("%v", val.Field(i))
	}
	return &SliceReader{
		tableName: tableName,
		names:     names,
		types:     types,
		data:      [][]any{row},
	}
}
// sliceReader dispatches on the element kind of a slice and builds the
// matching SliceReader. An empty slice yields a single text column
// with no rows.
func sliceReader(tableName string, val reflect.Value) *SliceReader {
	if val.Len() == 0 {
		return &SliceReader{
			tableName: tableName,
			names:     []string{"c1"},
			types:     []string{"text"},
			data:      nil,
		}
	}
	switch val.Index(0).Kind() {
	case reflect.Struct:
		// e.g. []struct{ ID int; Name string }{...}
		return structSliceReader(tableName, val)
	case reflect.Slice:
		// e.g. [][]any{{1, "test"}, {2, "test2"}}
		return sliceSliceReader(tableName, val)
	}
	// e.g. []string{"a", "b", "c"}
	return interfaceSliceReader(tableName, val)
}
// structSliceReader converts a slice of structs into a SliceReader:
// one row per struct, one column per field, each value rendered with %v.
func structSliceReader(tableName string, val reflect.Value) *SliceReader {
	t := val.Index(0).Type()
	n := t.NumField()
	names := make([]string, n)
	types := make([]string, n)
	for i := 0; i < n; i++ {
		field := t.Field(i)
		names[i] = field.Name
		types[i] = typeToDBType(field.Type.Kind())
	}
	data := make([][]any, 0, val.Len())
	for i := 0; i < val.Len(); i++ {
		item := val.Index(i)
		row := make([]any, item.NumField())
		for j := range row {
			row[j] = fmt.Sprintf("%v", item.Field(j))
		}
		data = append(data, row)
	}
	return &SliceReader{
		tableName: tableName,
		names:     names,
		types:     types,
		data:      data,
	}
}
// sliceSliceReader converts a slice of slices (rows of columns) into a
// SliceReader. Column names are generated as c1, c2, ... and the
// column types are inferred from the first row.
func sliceSliceReader(tableName string, val reflect.Value) *SliceReader {
	length := val.Len()
	firstRow := val.Index(0)
	columnNum := firstRow.Len()
	names := make([]string, columnNum)
	types := make([]string, columnNum)
	for i := 0; i < columnNum; i++ {
		names[i] = fmt.Sprintf("c%d", i+1)
		colType := reflect.ValueOf(firstRow.Index(i).Interface()).Kind()
		types[i] = typeToDBType(colType)
	}
	data := make([][]any, 0, length)
	for i := 0; i < length; i++ {
		// Copy elements via reflection instead of asserting to []any,
		// which panicked for typed element slices such as [][]string.
		rowVal := val.Index(i)
		row := make([]any, rowVal.Len())
		for j := range row {
			row[j] = rowVal.Index(j).Interface()
		}
		data = append(data, row)
	}
	return &SliceReader{
		tableName: tableName,
		names:     names,
		types:     types,
		data:      data,
	}
}
// interfaceSliceReader converts a flat slice into a single-column
// SliceReader named c1, inferring the type from the first element.
func interfaceSliceReader(tableName string, val reflect.Value) *SliceReader {
	first := reflect.ValueOf(val.Index(0).Interface())
	data := make([][]any, val.Len())
	for i := range data {
		data[i] = []any{val.Index(i).Interface()}
	}
	return &SliceReader{
		tableName: tableName,
		names:     []string{"c1"},
		types:     []string{typeToDBType(first.Kind())},
		data:      data,
	}
}
// typeToDBType maps a reflect.Kind to a database column type.
// Only signed integer kinds map to "int"; every other kind uses the
// DefaultDBType.
func typeToDBType(t reflect.Kind) string {
	switch t {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return "int"
	}
	return DefaultDBType
}
// TableName returns the table name given at construction.
func (r *SliceReader) TableName() (string, error) {
	return r.tableName, nil
}

// Names returns the column names.
func (r *SliceReader) Names() ([]string, error) {
	return r.names, nil
}

// Types returns the column types.
func (r *SliceReader) Types() ([]string, error) {
	return r.types, nil
}

// PreReadRow returns the in-memory rows; all data is available up front.
func (r *SliceReader) PreReadRow() [][]any {
	return r.data
}

// ReadRow always reports io.EOF because every row is delivered by
// PreReadRow.
func (r *SliceReader) ReadRow(row []any) ([]any, error) {
	return nil, io.EOF
}
package trdsql
import (
"errors"
"io"
"strconv"
"github.com/noborus/tbln"
)
// TBLNRead provides methods of the Reader interface.
type TBLNRead struct {
	reader    tbln.Reader // underlying TBLN stream reader
	inNULL    string      // string treated as NULL when needNULL is set
	preRead   [][]any     // rows read ahead during construction
	limitRead bool        // when true, ReadRow returns io.EOF immediately
	needNULL  bool        // convert inNULL values to nil
}
// NewTBLNReader returns TBLNRead and error.
// It reads up to opts.InPreRead rows ahead so names and types are
// available before query execution. Hitting EOF during pre-read is not
// an error: the rows collected so far are kept.
func NewTBLNReader(reader io.Reader, opts *ReadOpts) (*TBLNRead, error) {
	r := &TBLNRead{}
	r.reader = tbln.NewReader(reader)
	r.limitRead = opts.InLimitRead
	r.needNULL = opts.InNeedNULL
	r.inNULL = opts.InNULL
	rec, err := r.reader.ReadRow()
	if err != nil {
		if !errors.Is(err, io.EOF) {
			return r, err
		}
		// EOF on the very first row: empty input is acceptable.
		debug.Printf(err.Error())
		return r, nil
	}
	// SetNames if there is no names header.
	d := r.reader.GetDefinition()
	names := d.Names()
	if len(names) == 0 {
		// Generate c1, c2, ... matching the first record's width.
		names = make([]string, len(rec))
		for i := range rec {
			names[i] = "c" + strconv.Itoa(i+1)
		}
		if err := d.SetNames(names); err != nil {
			return r, err
		}
	}
	// SetTypes if there is no types header.
	types := d.Types()
	if len(types) == 0 {
		types = make([]string, len(rec))
		for i := range rec {
			types[i] = DefaultDBType
		}
		if err := d.SetTypes(types); err != nil {
			return r, err
		}
	}
	r.preRead = make([][]any, 0, opts.InPreRead)
	r.preRead = append(r.preRead, r.recToRow(rec))
	for n := 1; n < opts.InPreRead; n++ {
		rec, err := r.reader.ReadRow()
		if err != nil {
			if !errors.Is(err, io.EOF) {
				return r, err
			}
			// Fewer rows than InPreRead; keep what was collected.
			debug.Printf(err.Error())
			return r, nil
		}
		r.preRead = append(r.preRead, r.recToRow(rec))
	}
	return r, nil
}
// Names returns the column names from the TBLN definition.
func (r *TBLNRead) Names() ([]string, error) {
	if r.reader == nil {
		return nil, ErrNonDefinition
	}
	return r.reader.GetDefinition().Names(), nil
}

// Types returns the column types from the TBLN definition.
func (r *TBLNRead) Types() ([]string, error) {
	if r.reader == nil {
		return nil, ErrNonDefinition
	}
	return r.reader.GetDefinition().Types(), nil
}

// PreReadRow returns the rows read ahead during construction.
func (r *TBLNRead) PreReadRow() [][]any {
	return r.preRead
}
// ReadRow reads one remaining record, honoring the read limit.
func (r *TBLNRead) ReadRow(row []any) ([]any, error) {
	if r.limitRead {
		return nil, io.EOF
	}
	rec, err := r.reader.ReadRow()
	if err != nil {
		return row, err
	}
	return r.recToRow(rec), nil
}

// recToRow converts a string record into a row of any values.
// Empty fields are left as nil; when needNULL is set, fields equal to
// inNULL are also converted to nil.
func (r *TBLNRead) recToRow(rec []string) []any {
	row := make([]any, len(rec))
	for i, field := range rec {
		if field != "" {
			row[i] = field
		}
		if r.needNULL && row[i] == r.inNULL {
			row[i] = nil
		}
	}
	return row
}
package trdsql
import (
"bufio"
"io"
"strings"
)
// TextReader provides a reader for text format.
type TextReader struct {
	reader *bufio.Reader
	num    int // lines returned so far
	maxNum int // when > 0, maximum number of lines to return
}
// NewTextReader returns a new TextReader.
// InSkip leading lines are discarded up front; when InLimitRead is set,
// at most InPreRead lines will ever be returned.
func NewTextReader(reader io.Reader, opts *ReadOpts) (*TextReader, error) {
	t := &TextReader{reader: bufio.NewReader(reader)}
	if opts.InSkip > 0 {
		skipRead(t, opts.InSkip)
	}
	if opts.InLimitRead {
		t.maxNum = opts.InPreRead
	}
	return t, nil
}
// Names returns the single column name "text".
func (r *TextReader) Names() ([]string, error) {
	return []string{"text"}, nil
}

// Types returns the single column type "text".
func (r *TextReader) Types() ([]string, error) {
	return []string{"text"}, nil
}

// PreReadRow returns nil; text rows are only produced by ReadRow.
func (r *TextReader) PreReadRow() [][]any {
	return nil
}
// ReadRow returns the next line as a single-column row, joining the
// continuation chunks of overlong lines. Once the configured line
// limit is reached it reports io.EOF.
func (r *TextReader) ReadRow([]any) ([]any, error) {
	var line strings.Builder
	for {
		if r.maxNum > 0 && r.num >= r.maxNum {
			return []any{""}, io.EOF
		}
		chunk, isPrefix, err := r.reader.ReadLine()
		if err != nil {
			return []any{""}, err
		}
		line.Write(chunk)
		if isPrefix {
			// Line exceeded the buffer; keep accumulating.
			continue
		}
		r.num++
		return []any{line.String()}, nil
	}
}
package trdsql
import (
"bytes"
"errors"
"fmt"
"io"
"log"
"strings"
"github.com/goccy/go-yaml"
"github.com/itchyny/gojq"
)
// YAMLReader provides methods of the Reader interface.
type YAMLReader struct {
	reader    *yaml.Decoder
	query     *gojq.Query      // optional jq-style filter applied to documents
	already   map[string]bool  // column names already registered (dedup)
	inNULL    string           // string treated as NULL when needNULL is set
	preRead   []map[string]any // rows read ahead during construction
	names     []string         // column names in first-seen order
	types     []string
	limitRead bool // when true, ReadRow returns io.EOF immediately
	needNULL  bool
}
// NewYAMLReader returns YAMLReader and error.
// The optional opts.InJQuery expression is compiled first; the input is
// then parsed ahead according to opts.
func NewYAMLReader(reader io.Reader, opts *ReadOpts) (*YAMLReader, error) {
	query, err := jqParse(opts.InJQuery)
	if err != nil {
		return nil, err
	}
	r := &YAMLReader{
		query:   query,
		reader:  yaml.NewDecoder(reader),
		already: map[string]bool{},
	}
	if err := r.yamlParse(opts); err != nil {
		return nil, err
	}
	return r, nil
}
// jqParse compiles a jq expression into a *gojq.Query.
// An empty string yields a nil query (no filtering).
func jqParse(q string) (*gojq.Query, error) {
	if q == "" {
		return nil, nil
	}
	expr := trimQuoteAll(q)
	query, err := gojq.Parse(expr)
	if err != nil {
		return nil, fmt.Errorf("%w gojq:(%s)", err, expr)
	}
	return query, nil
}
// wrapDecode decodes one YAML document into v, converting any panic
// raised by the decoder into an ordinary error.
func (r *YAMLReader) wrapDecode(v any) (err error) {
	defer func() {
		if p := recover(); p != nil {
			err = fmt.Errorf("%s", p)
		}
	}()
	err = r.reader.Decode(v)
	return
}
// yamlParse parses YAML and stores it in preRead.
// Up to opts.InPreRead top-level documents are decoded; EOF simply ends
// the pre-read with whatever was collected. When a jq query is set,
// only the first document is decoded and the query's results are
// read ahead instead.
func (r *YAMLReader) yamlParse(opts *ReadOpts) error {
	r.limitRead = opts.InLimitRead
	r.needNULL = opts.InNeedNULL
	r.inNULL = opts.InNULL
	var top any
	for i := 0; i < opts.InPreRead; i++ {
		if err := r.wrapDecode(&top); err != nil {
			if !errors.Is(err, io.EOF) {
				return fmt.Errorf("%w: %s", ErrInvalidYAML, err)
			}
			// EOF: fewer documents than InPreRead; not an error.
			debug.Printf(err.Error())
			return nil
		}
		if r.query != nil {
			if err := r.jquery(top); err != nil {
				return err
			}
			// With a query, only the first document is read ahead;
			// any further documents are left to ReadRow.
			return nil
		}
		if err := r.readAhead(top); err != nil {
			return err
		}
	}
	return nil
}
// jquery runs the compiled jq query over one decoded document and
// reads ahead every value it produces. A query error aborts the scan.
func (r *YAMLReader) jquery(top any) error {
	iter := r.query.Run(top)
	for v, ok := iter.Next(); ok; v, ok = iter.Next() {
		if qerr, isErr := v.(error); isErr {
			return fmt.Errorf("%w gojq:(%s) ", qerr, r.query)
		}
		if err := r.readAhead(v); err != nil {
			return err
		}
	}
	return nil
}
// Names returns the column names gathered during pre-read.
func (r *YAMLReader) Names() ([]string, error) {
	return r.names, nil
}

// Types returns column types.
// All YAML types return the DefaultDBType.
func (r *YAMLReader) Types() ([]string, error) {
	r.types = make([]string, len(r.names))
	for i := range r.types {
		r.types[i] = DefaultDBType
	}
	return r.types, nil
}
// readAhead flattens one decoded document into preRead rows.
// A top-level sequence contributes one row per element; any other
// value (mapping, ordered mapping, or scalar) contributes one row via
// topLevel, which dispatches exactly as the per-case handling did.
func (r *YAMLReader) readAhead(top any) error {
	if list, ok := top.([]any); ok {
		for _, v := range list {
			pre, names, err := r.topLevel(v)
			if err != nil {
				return err
			}
			r.appendNames(names)
			r.preRead = append(r.preRead, pre)
		}
		return nil
	}
	pre, names, err := r.topLevel(top)
	if err != nil {
		return err
	}
	r.appendNames(names)
	r.preRead = append(r.preRead, pre)
	return nil
}
// appendNames registers names not seen before, preserving first-seen
// order.
func (r *YAMLReader) appendNames(names []string) {
	for _, name := range names {
		if r.already[name] {
			continue
		}
		r.already[name] = true
		r.names = append(r.names, name)
	}
}

// topLevel converts one document-level value into a row map plus its
// column names.
func (r *YAMLReader) topLevel(top any) (map[string]any, []string, error) {
	switch obj := top.(type) {
	case map[string]any:
		return r.objectRow(obj)
	case yaml.MapSlice:
		return r.objectMapSlice(obj)
	}
	return r.etcRow(top)
}
// PreReadRow returns the rows collected during pre-read, ordered by
// the registered column names. A single (non-streamed) YAML document
// is delivered entirely through this method.
func (r *YAMLReader) PreReadRow() [][]any {
	rows := make([][]any, len(r.preRead))
	for n, pre := range r.preRead {
		row := make([]any, len(r.names))
		for i, name := range r.names {
			row[i] = pre[name]
		}
		rows[n] = row
	}
	return rows
}

// ReadRow decodes and returns the next document as a row; it is only
// exercised for multi-document YAML streams.
func (r *YAMLReader) ReadRow(row []any) ([]any, error) {
	if r.limitRead {
		return nil, io.EOF
	}
	var doc any
	if err := r.reader.Decode(&doc); err != nil {
		return nil, err
	}
	return r.rowParse(row, doc), nil
}

// rowParse fills row from a decoded document. Mapping documents are
// matched by column name; any other document lands in the first column
// with the remaining columns cleared.
func (r *YAMLReader) rowParse(row []any, yamlRow any) []any {
	if m, ok := yamlRow.(map[string]any); ok {
		for i, name := range r.names {
			row[i] = r.toString(m[name])
		}
		return row
	}
	for i := range r.names {
		row[i] = nil
	}
	row[0] = r.toString(yamlRow)
	return row
}
// objectRow converts a mapping into a row map and its column names.
// Note: Go map iteration order is unspecified, so the name order may
// vary between runs.
func (r *YAMLReader) objectRow(obj map[string]any) (map[string]any, []string, error) {
	names := make([]string, 0, len(obj))
	row := make(map[string]any, len(obj))
	for key, value := range obj {
		names = append(names, key)
		row[key] = r.toString(value)
	}
	return row, names, nil
}
// objectMapSlice converts a yaml.MapSlice (order-preserving YAML
// object) into a row map and its column names in document order.
func (r *YAMLReader) objectMapSlice(obj yaml.MapSlice) (map[string]any, []string, error) {
	names := make([]string, 0, len(obj))
	row := make(map[string]any)
	for _, item := range obj {
		// YAML keys are not necessarily strings (numeric or boolean
		// keys are legal); fall back to their printed form instead of
		// panicking on the type assertion.
		key, ok := item.Key.(string)
		if !ok {
			key = fmt.Sprintf("%v", item.Key)
		}
		names = append(names, key)
		row[key] = r.toString(item.Value)
	}
	return row, names, nil
}
// etcRow wraps a scalar value in a single-column row named c1.
func (r *YAMLReader) etcRow(val any) (map[string]any, []string, error) {
	const name = "c1"
	row := map[string]any{name: r.toString(val)}
	return row, []string{name}, nil
}
// toString returns a string representation of val.
// It will be YAML if val is a struct or map, otherwise it will be a string representation of val.
// nil passes through unchanged so NULL survives; the trailing newline
// added by yaml.Marshal is stripped. When needNULL is set, values equal
// to inNULL are converted to nil.
func (r *YAMLReader) toString(val any) any {
	var str string
	switch t := val.(type) {
	case nil:
		return nil
	case map[string]any, []yaml.MapSlice, []any:
		// NOTE(review): a bare yaml.MapSlice is not matched here and
		// falls through to ValString below — confirm that is intended.
		b, err := yaml.Marshal(val)
		if err != nil {
			log.Printf("ERROR: YAMLString:%s", err)
		}
		str = yamlToStr(b)
	case []byte:
		str = yamlToStr(t)
	case string:
		str = yamlToStr([]byte(t))
	default:
		str = ValString(t)
	}
	// Remove the last newline.
	str = strings.TrimRight(str, "\n")
	if r.needNULL {
		return replaceNULL(r.inNULL, str)
	}
	return str
}
// yamlToStr converts marshalled YAML to a string, preferring a JSON
// rendering for multi-line (structured) values so they fit one row.
func yamlToStr(buf []byte) string {
	if !bytes.Contains(buf, []byte("\n")) {
		return ValString(buf)
	}
	// Convert to JSON if it's a YAML element.
	if j, err := yaml.YAMLToJSON(buf); err == nil {
		return ValString(j)
	}
	return ValString(buf)
}
package trdsql
import (
"bufio"
"strings"
"unicode"
"unicode/utf8"
)
// CSVWriter provides methods of the Writer interface.
type CSVWriter struct {
	writer       *bufio.Writer
	needQuotes   string // characters whose presence forces quoting
	endLine      string // line terminator ("\n" or "\r\n")
	outNULL      string // replacement text for NULL when needNULL is set
	outDelimiter rune
	outQuote     rune
	outHeader    bool // write a header row in PreWrite
	outAllQuote  bool // quote every field unconditionally
	outUseCRLF   bool
	needNULL     bool
}
// NewCSVWriter returns CSVWriter configured from writeOpts.
// A bad delimiter option is logged and leaves the zero delimiter.
func NewCSVWriter(writeOpts *WriteOpts) *CSVWriter {
	w := &CSVWriter{
		writer:      bufio.NewWriter(writeOpts.OutStream),
		outAllQuote: writeOpts.OutAllQuotes,
		outUseCRLF:  writeOpts.OutUseCRLF,
		outHeader:   writeOpts.OutHeader,
		needNULL:    writeOpts.OutNeedNULL,
		outNULL:     writeOpts.OutNULL,
	}
	d, err := delimiter(writeOpts.OutDelimiter)
	if err != nil {
		debug.Printf("%s\n", err)
	}
	w.outDelimiter = d
	// Only the first rune of the quote option is used.
	if quote := []rune(writeOpts.OutQuote); len(quote) > 0 {
		w.outQuote = quote[0]
	}
	w.needQuotes = string(w.outDelimiter) + string(w.outQuote) + "\r\n"
	w.endLine = "\n"
	if writeOpts.OutUseCRLF {
		w.endLine = "\r\n"
	}
	return w
}
// PreWrite writes the header row when outHeader is enabled.
func (w *CSVWriter) PreWrite(columns []string, types []string) error {
	if !w.outHeader {
		return nil
	}
	for i, column := range columns {
		if i > 0 {
			if _, err := w.writer.WriteRune(w.outDelimiter); err != nil {
				return err
			}
		}
		if err := w.writeColumnString(column); err != nil {
			return err
		}
	}
	_, err := w.writer.WriteString(w.endLine)
	return err
}

// WriteRow writes one delimited row.
func (w *CSVWriter) WriteRow(values []any, _ []string) error {
	for i, value := range values {
		if i > 0 {
			if _, err := w.writer.WriteRune(w.outDelimiter); err != nil {
				return err
			}
		}
		if err := w.writeColumn(value); err != nil {
			return err
		}
	}
	_, err := w.writer.WriteString(w.endLine)
	return err
}
// writeColumn writes one field; nil becomes the configured NULL text
// (or nothing when needNULL is off).
func (w *CSVWriter) writeColumn(column any) error {
	if column != nil {
		return w.writeColumnString(ValString(column))
	}
	out := ""
	if w.needNULL {
		out = w.outNULL
	}
	_, err := w.writer.WriteString(out)
	return err
}
// writeColumnString writes one field, quoting it when required.
// Inside quotes, quote runes are doubled; '\r' is dropped when CRLF
// output is on (the '\n' case re-emits it), and '\n' becomes "\r\n"
// in CRLF mode. This mirrors encoding/csv's escaping rules.
func (w *CSVWriter) writeColumnString(column string) error {
	if !w.fieldNeedsQuotes(column) {
		_, err := w.writer.WriteString(column)
		return err
	}
	if _, err := w.writer.WriteRune(w.outQuote); err != nil {
		return err
	}
	var err error
	for _, r1 := range column {
		switch r1 {
		case w.outQuote:
			// Escape the quote character by doubling it.
			_, err = w.writer.WriteString(string([]rune{w.outQuote, w.outQuote}))
		case '\r':
			if !w.outUseCRLF {
				err = w.writer.WriteByte('\r')
			}
		case '\n':
			if w.outUseCRLF {
				_, err = w.writer.WriteString("\r\n")
			} else {
				err = w.writer.WriteByte('\n')
			}
		default:
			_, err = w.writer.WriteRune(r1)
		}
		if err != nil {
			return err
		}
	}
	_, err = w.writer.WriteRune(w.outQuote)
	return err
}
// fieldNeedsQuotes reports whether a field must be quoted: always in
// all-quote mode; when it contains the delimiter, quote, CR, or LF;
// when it is the literal `\.`; or when it starts with whitespace.
func (w *CSVWriter) fieldNeedsQuotes(field string) bool {
	switch {
	case w.outAllQuote:
		return true
	case field == "":
		return false
	case field == `\.`, strings.ContainsAny(field, w.needQuotes):
		return true
	}
	first, _ := utf8.DecodeRuneInString(field)
	return unicode.IsSpace(first)
}

// PostWrite flushes buffered output.
func (w *CSVWriter) PostWrite() error {
	return w.writer.Flush()
}
package trdsql
import (
"encoding/hex"
"encoding/json"
"unicode/utf8"
"github.com/iancoleman/orderedmap"
)
// JSONWriter provides methods of the Writer interface.
// Rows are buffered as ordered maps and emitted as one JSON array in
// PostWrite.
type JSONWriter struct {
	writer   *json.Encoder
	outNULL  string
	results  []*orderedmap.OrderedMap
	needNULL bool
}

// NewJSONWriter returns JSONWriter.
func NewJSONWriter(writeOpts *WriteOpts) *JSONWriter {
	w := &JSONWriter{
		needNULL: writeOpts.OutNeedNULL,
		outNULL:  writeOpts.OutNULL,
	}
	w.writer = json.NewEncoder(writeOpts.OutStream)
	w.writer.SetIndent("", " ")
	return w
}

// PreWrite resets the row buffer.
func (w *JSONWriter) PreWrite(columns []string, types []string) error {
	w.results = make([]*orderedmap.OrderedMap, 0)
	return nil
}

// WriteRow buffers one row, preserving column order.
func (w *JSONWriter) WriteRow(values []any, columns []string) error {
	row := orderedmap.New()
	for i, value := range values {
		row.Set(columns[i], compatibleJSON(value, w.needNULL, w.outNULL))
	}
	w.results = append(w.results, row)
	return nil
}
// compatibleJSON converts the value to a JSON-compatible value.
// []byte and string values that already contain JSON are passed
// through as raw JSON; non-UTF-8 byte slices are hex encoded.
// A nil value is replaced by outNULL when needNULL is set.
func compatibleJSON(v any, needNULL bool, outNULL string) any {
	switch t := v.(type) {
	case []byte:
		if isJSON(t) {
			return json.RawMessage(t)
		}
		if ok := utf8.Valid(t); ok {
			return string(t)
		}
		return `\x` + hex.EncodeToString(t)
	case string:
		if isJSON([]byte(t)) {
			return json.RawMessage(t)
		}
		return v
	default:
		// Only NULL (nil) must be replaced; other non-string values
		// (numbers, booleans, ...) pass through unchanged. The previous
		// code replaced every value in this branch whenever needNULL
		// was set, turning e.g. integers into the NULL text.
		if v == nil && needNULL {
			return outNULL
		}
		return v
	}
}

// isJSON returns true if the byte array is JSON.
func isJSON(s []byte) bool {
	if len(s) == 0 {
		return false
	}
	// Except for JSONArray or JSONObject
	if s[0] != '[' && s[0] != '{' {
		return false
	}
	var js any
	err := json.Unmarshal(s, &js)
	return err == nil
}
// PostWrite encodes the accumulated rows as a single JSON array.
func (w *JSONWriter) PostWrite() error {
	return w.writer.Encode(w.results)
}
package trdsql
import (
"encoding/json"
"github.com/iancoleman/orderedmap"
)
// JSONLWriter provides methods of the Writer interface.
// Each row is emitted immediately as one JSON object per line.
type JSONLWriter struct {
	writer   *json.Encoder
	outNULL  string
	needNULL bool
}

// NewJSONLWriter returns JSONLWriter.
func NewJSONLWriter(writeOpts *WriteOpts) *JSONLWriter {
	return &JSONLWriter{
		writer:   json.NewEncoder(writeOpts.OutStream),
		needNULL: writeOpts.OutNeedNULL,
		outNULL:  writeOpts.OutNULL,
	}
}

// PreWrite does nothing; JSONL needs no setup.
func (w *JSONLWriter) PreWrite(columns []string, types []string) error {
	return nil
}

// WriteRow encodes one row as a single JSON line, preserving column
// order.
func (w *JSONLWriter) WriteRow(values []any, columns []string) error {
	row := orderedmap.New()
	for i, value := range values {
		row.Set(columns[i], compatibleJSON(value, w.needNULL, w.outNULL))
	}
	return w.writer.Encode(row)
}

// PostWrite does nothing; rows are flushed as they are written.
func (w *JSONLWriter) PostWrite() error {
	return nil
}
package trdsql
import (
"bufio"
)
// LTSVWriter provides methods of the Writer interface.
type LTSVWriter struct {
	writer    *bufio.Writer
	outNULL   string   // replacement text for NULL when needNULL is set
	results   []string // allocated in PreWrite; NOTE(review): not used by WriteRow — confirm it is needed
	delimiter rune     // field separator, always '\t'
	needNULL  bool
}
// NewLTSVWriter returns LTSVWriter.
func NewLTSVWriter(writeOpts *WriteOpts) *LTSVWriter {
w := <SVWriter{}
w.delimiter = '\t'
w.writer = bufio.NewWriter(writeOpts.OutStream)
w.needNULL = writeOpts.OutNeedNULL
w.outNULL = writeOpts.OutNULL
return w
}
// PreWrite allocates the scratch row; LTSV emits no header.
func (w *LTSVWriter) PreWrite(columns []string, types []string) error {
	w.results = make([]string, len(columns))
	return nil
}

// WriteRow writes one row as tab-separated label:value pairs.
// nil values become the configured NULL text when needNULL is set.
func (w *LTSVWriter) WriteRow(values []any, labels []string) error {
	for i, value := range values {
		if i > 0 {
			if _, err := w.writer.WriteRune(w.delimiter); err != nil {
				return err
			}
		}
		if _, err := w.writer.WriteString(labels[i]); err != nil {
			return err
		}
		if err := w.writer.WriteByte(':'); err != nil {
			return err
		}
		field := ValString(value)
		if value == nil && w.needNULL {
			field = w.outNULL
		}
		if _, err := w.writer.WriteString(field); err != nil {
			return err
		}
	}
	return w.writer.WriteByte('\n')
}

// PostWrite flushes buffered output.
func (w *LTSVWriter) PostWrite() error {
	return w.writer.Flush()
}
package trdsql
import (
"bufio"
"strconv"
)
// RAWWriter provides methods of the Writer interface.
type RAWWriter struct {
	writer    *bufio.Writer
	delimiter string // unquoted OutDelimiter (so "\t" etc. work)
	endLine   string // line terminator ("\n" or "\r\n")
	outNULL   string // replacement text for NULL when needNULL is set
	outHeader bool   // write a header row in PreWrite
	needNULL  bool
}
// NewRAWWriter returns RAWWriter.
// The delimiter option is unquoted so escape sequences like "\t" are
// honored; an unquotable delimiter is logged and left empty.
func NewRAWWriter(writeOpts *WriteOpts) *RAWWriter {
	delim, err := strconv.Unquote(`"` + writeOpts.OutDelimiter + `"`)
	if err != nil {
		debug.Printf("%s\n", err)
	}
	w := &RAWWriter{
		writer:    bufio.NewWriter(writeOpts.OutStream),
		delimiter: delim,
		outHeader: writeOpts.OutHeader,
		endLine:   "\n",
		needNULL:  writeOpts.OutNeedNULL,
		outNULL:   writeOpts.OutNULL,
	}
	if writeOpts.OutUseCRLF {
		w.endLine = "\r\n"
	}
	return w
}
// PreWrite writes the header row when outHeader is enabled.
func (w *RAWWriter) PreWrite(columns []string, types []string) error {
	if !w.outHeader {
		return nil
	}
	for i, col := range columns {
		if i > 0 {
			if _, err := w.writer.WriteString(w.delimiter); err != nil {
				return err
			}
		}
		if _, err := w.writer.WriteString(col); err != nil {
			return err
		}
	}
	_, err := w.writer.WriteString(w.endLine)
	return err
}
// WriteRow writes one delimited row.
// nil values become the configured NULL text when needNULL is set.
func (w *RAWWriter) WriteRow(values []any, _ []string) error {
	for n, col := range values {
		if n > 0 {
			if _, err := w.writer.WriteString(w.delimiter); err != nil {
				return err
			}
		}
		str := ValString(col)
		if col == nil && w.needNULL {
			str = w.outNULL
		}
		if _, err := w.writer.WriteString(str); err != nil {
			return err
		}
	}
	// Use the configured terminator so data rows match the header
	// written by PreWrite; the previous hard-coded '\n' ignored
	// OutUseCRLF and produced mixed line endings.
	_, err := w.writer.WriteString(w.endLine)
	return err
}
// PostWrite flushes buffered output.
func (w *RAWWriter) PostWrite() error {
	return w.writer.Flush()
}
package trdsql
// SliceWriter is a structure to receive the result in slice.
type SliceWriter struct {
	Table [][]any // accumulated result rows
}

// NewSliceWriter return SliceWriter.
func NewSliceWriter() *SliceWriter {
	return &SliceWriter{}
}

// PreWrite resets the result table.
func (w *SliceWriter) PreWrite(columns []string, types []string) error {
	w.Table = make([][]any, 0)
	return nil
}

// WriteRow copies one row into the result table.
// The copy decouples the stored row from the caller's reused slice.
func (w *SliceWriter) WriteRow(values []any, columns []string) error {
	row := make([]any, len(values))
	copy(row, values)
	w.Table = append(w.Table, row)
	return nil
}

// PostWrite does nothing; the result lives in Table.
func (w *SliceWriter) PostWrite() error {
	return nil
}
package trdsql
import (
"strings"
"github.com/olekukonko/tablewriter"
)
// TWWriter provides methods of the Writer interface.
// Output is an ASCII table, or a Markdown table when markdown is set.
type TWWriter struct {
	writeOpts *WriteOpts
	writer    *tablewriter.Table
	outNULL   string
	results   []string // scratch row reused for each WriteRow
	needNULL  bool
	markdown  bool
}

// NewTWWriter returns TWWriter.
func NewTWWriter(writeOpts *WriteOpts, markdown bool) *TWWriter {
	return &TWWriter{
		writeOpts: writeOpts,
		needNULL:  writeOpts.OutNeedNULL,
		outNULL:   writeOpts.OutNULL,
		markdown:  markdown,
	}
}

// PreWrite configures the table writer and installs the header.
func (w *TWWriter) PreWrite(columns []string, types []string) error {
	w.writer = tablewriter.NewWriter(w.writeOpts.OutStream)
	w.writer.SetAutoFormatHeaders(false)
	w.writer.SetAutoWrapText(!w.writeOpts.OutNoWrap)
	if w.markdown {
		// Markdown tables use pipe separators without top/bottom borders.
		w.writer.SetBorders(tablewriter.Border{Left: true, Top: false, Right: true, Bottom: false})
		w.writer.SetCenterSeparator("|")
	}
	w.writer.SetHeader(columns)
	w.results = make([]string, len(columns))
	return nil
}

// WriteRow appends one row, escaping pipes in Markdown mode and
// substituting the NULL text for nil values when configured.
func (w *TWWriter) WriteRow(values []any, columns []string) error {
	for i, value := range values {
		cell := ValString(value)
		if w.markdown {
			cell = strings.ReplaceAll(cell, `|`, `\|`)
		}
		if value == nil && w.needNULL {
			cell = w.outNULL
		}
		w.results[i] = cell
	}
	w.writer.Append(w.results)
	return nil
}

// PostWrite renders the accumulated table.
func (w *TWWriter) PostWrite() error {
	w.writer.Render()
	return nil
}
package trdsql
import (
"strings"
"github.com/noborus/tbln"
)
// TBLNWriter provides methods of the Writer interface.
type TBLNWriter struct {
	writer   *tbln.Writer
	outNULL  string
	results  []string // scratch row reused for each WriteRow
	needNULL bool
}

// NewTBLNWriter returns TBLNWriter.
func NewTBLNWriter(writeOpts *WriteOpts) *TBLNWriter {
	return &TBLNWriter{
		writer:   tbln.NewWriter(writeOpts.OutStream),
		needNULL: writeOpts.OutNeedNULL,
		outNULL:  writeOpts.OutNULL,
	}
}

// PreWrite writes the TBLN definition: column names plus the types
// converted to TBLN's common type set.
func (w *TBLNWriter) PreWrite(columns []string, types []string) error {
	d := tbln.NewDefinition()
	if err := d.SetNames(columns); err != nil {
		return err
	}
	if err := d.SetTypes(ConvertTypes(types)); err != nil {
		return err
	}
	if err := w.writer.WriteDefinition(d); err != nil {
		return err
	}
	w.results = make([]string, len(columns))
	return nil
}

// WriteRow writes one row. Newlines are escaped because TBLN records
// are line-based; nil becomes the NULL text when configured.
func (w *TBLNWriter) WriteRow(values []any, columns []string) error {
	for i, value := range values {
		cell := ValString(value)
		if value == nil && w.needNULL {
			cell = w.outNULL
		}
		w.results[i] = strings.ReplaceAll(cell, "\n", "\\n")
	}
	return w.writer.WriteRow(w.results)
}

// PostWrite does nothing; rows are written as they arrive.
func (w *TBLNWriter) PostWrite() error {
	return nil
}
// ConvertTypes converts database types to TBLN's common types.
func ConvertTypes(dbTypes []string) []string {
	converted := make([]string, len(dbTypes))
	for i, dbType := range dbTypes {
		converted[i] = convertType(dbType)
	}
	return converted
}
// convertType maps a database-specific type name (case-insensitive) to
// one of the common TBLN types: int, bigint, numeric, bool, timestamp,
// or text.
func convertType(dbType string) string {
	switch strings.ToLower(dbType) {
	case "smallint", "integer", "int", "int2", "int4", "smallserial", "serial":
		return "int"
	case "bigint", "int8", "bigserial":
		return "bigint"
	case "float", "decimal", "numeric", "real", "double precision":
		return "numeric"
	case "bool":
		return "bool"
	case "timestamp", "timestamptz", "date", "time":
		return "timestamp"
	}
	// Covers "string", "text", "char", "varchar" and any unknown type.
	return "text"
}
package trdsql
import (
"bufio"
"fmt"
"strings"
runewidth "github.com/mattn/go-runewidth"
"golang.org/x/term"
)
// VFWriter is Vertical Format output: each row is printed as a block
// of "column | value" lines.
type VFWriter struct {
	writer    *bufio.Writer
	outNULL   string
	header    []string // column names captured in PreWrite
	termWidth int      // terminal width; 40 when undetectable
	hSize     int      // display width of the widest header
	count     int      // rows written so far
	needNULL  bool
}

// NewVFWriter returns VFWriter.
// The terminal width falls back to 40 when it cannot be determined.
func NewVFWriter(writeOpts *WriteOpts) *VFWriter {
	w := &VFWriter{
		writer:   bufio.NewWriter(writeOpts.OutStream),
		needNULL: writeOpts.OutNeedNULL,
		outNULL:  writeOpts.OutNULL,
	}
	width, _, err := term.GetSize(0)
	if err != nil {
		width = 40
	}
	w.termWidth = width
	return w
}

// PreWrite records the header and the widest header display width.
func (w *VFWriter) PreWrite(columns []string, types []string) error {
	w.count = 0
	w.header = make([]string, len(columns))
	w.hSize = 0
	for i, col := range columns {
		w.header[i] = col
		if width := runewidth.StringWidth(col); width > w.hSize {
			w.hSize = width
		}
	}
	return nil
}
// WriteRow prints one row as a numbered separator line followed by one
// right-aligned "column | value" line per column.
func (w *VFWriter) WriteRow(values []any, columns []string) error {
	w.count++
	// Clamp the separator width: strings.Repeat panics on a negative
	// count, which the old code hit on terminals narrower than 16
	// columns.
	sepWidth := w.termWidth - 16
	if sepWidth < 0 {
		sepWidth = 0
	}
	_, err := fmt.Fprintf(w.writer,
		"---[ %d]%s\n", w.count, strings.Repeat("-", sepWidth))
	if err != nil {
		debug.Printf("%s\n", err)
	}
	for i, col := range w.header {
		v := w.hSize - runewidth.StringWidth(col)
		str := ValString(values[i])
		if values[i] == nil && w.needNULL {
			str = w.outNULL
		}
		_, err := fmt.Fprintf(w.writer,
			"%s%s | %-s\n",
			strings.Repeat(" ", v+2),
			col,
			str)
		if err != nil {
			debug.Printf("%s\n", err)
		}
	}
	return nil
}
// PostWrite flushes buffered output.
func (w *VFWriter) PostWrite() error {
	return w.writer.Flush()
}
package trdsql
import (
"encoding/hex"
"unicode/utf8"
"github.com/goccy/go-yaml"
)
// YAMLWriter provides methods of the Writer interface.
// Rows are buffered as order-preserving MapSlices and emitted together
// in PostWrite.
type YAMLWriter struct {
	writer   *yaml.Encoder
	outNULL  string
	results  []yaml.MapSlice
	needNULL bool
}

// NewYAMLWriter returns YAMLWriter.
func NewYAMLWriter(writeOpts *WriteOpts) *YAMLWriter {
	return &YAMLWriter{
		writer:   yaml.NewEncoder(writeOpts.OutStream),
		needNULL: writeOpts.OutNeedNULL,
		outNULL:  writeOpts.OutNULL,
	}
}

// PreWrite resets the row buffer.
func (w *YAMLWriter) PreWrite(columns []string, types []string) error {
	w.results = make([]yaml.MapSlice, 0)
	return nil
}

// WriteRow buffers one row as a MapSlice, preserving column order.
func (w *YAMLWriter) WriteRow(values []any, columns []string) error {
	row := make(yaml.MapSlice, len(values))
	for i, value := range values {
		row[i].Key = columns[i]
		row[i].Value = compatibleYAML(value, w.needNULL, w.outNULL)
	}
	w.results = append(w.results, row)
	return nil
}
// compatibleYAML converts the value to a YAML-compatible value.
// []byte and string values containing parseable YAML are decoded;
// non-UTF-8 byte slices are hex encoded. A nil value is replaced by
// outNULL when needNULL is set.
func compatibleYAML(v any, needNULL bool, outNULL string) any {
	var yl any
	switch t := v.(type) {
	case []byte:
		if err := yaml.Unmarshal(t, &yl); err == nil {
			return yl
		}
		if ok := utf8.Valid(t); ok {
			return string(t)
		}
		return `\x` + hex.EncodeToString(t)
	case string:
		y := []byte(t)
		if err := yaml.Unmarshal(y, &yl); err == nil {
			return yl
		}
		return v
	default:
		// Only NULL (nil) must be replaced; other non-string values
		// (numbers, booleans, ...) pass through unchanged. The previous
		// code replaced every value in this branch whenever needNULL
		// was set.
		if v == nil && needNULL {
			return outNULL
		}
		return v
	}
}
// PostWrite encodes the accumulated rows as one YAML document.
func (w *YAMLWriter) PostWrite() error {
	return w.writer.Encode(w.results)
}
package trdsql
import (
"io"
"log"
"sync"
)
// extToFormat is a map of file extensions to formats.
// Keys are upper-case extension names; RegisterReaderFunc may add more.
var extToFormat map[string]Format = map[string]Format{
	"CSV":   CSV,
	"LTSV":  LTSV,
	"JSON":  JSON,
	"JSONL": JSON, // JSONL shares the JSON reader
	"YAML":  YAML,
	"YML":   YAML,
	"TBLN":  TBLN,
	"TSV":   TSV,
	"PSV":   PSV,
	"WIDTH": WIDTH,
	"TEXT":  TEXT,
}
// ReaderFunc is a function that creates a new Reader from an input
// stream and read options. Used to register readers per Format.
type ReaderFunc func(io.Reader, *ReadOpts) (Reader, error)
// readerFuncs maps formats to their corresponding ReaderFunc.
// Each constructor returns a concrete reader type, so a small closure
// adapts it to the (Reader, error) signature of ReaderFunc.
var readerFuncs = map[Format]ReaderFunc{
	CSV: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
		return NewCSVReader(reader, opts)
	},
	LTSV: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
		return NewLTSVReader(reader, opts)
	},
	JSON: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
		return NewJSONReader(reader, opts)
	},
	YAML: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
		return NewYAMLReader(reader, opts)
	},
	TBLN: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
		return NewTBLNReader(reader, opts)
	},
	TSV: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
		return NewTSVReader(reader, opts)
	},
	PSV: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
		return NewPSVReader(reader, opts)
	},
	WIDTH: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
		return NewGWReader(reader, opts)
	},
	TEXT: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
		return NewTextReader(reader, opts)
	},
}
var (
	// extFormat is the next format number to be assigned.
	// Built-in formats occupy low values; registered formats start at 100.
	extFormat Format = 100
	// registerMux is a mutex to protect access to the register
	// (extToFormat, readerFuncs and extFormat).
	registerMux = &sync.Mutex{}
)
// RegisterReaderFunc registers a ReaderFunc for the given file extension,
// assigning it a fresh Format number. Safe for concurrent use.
// NOTE(review): registering the same extension twice allocates a new
// Format and leaves the previous readerFuncs entry unreachable.
func RegisterReaderFunc(ext string, readerFunc ReaderFunc) {
	registerMux.Lock()
	defer registerMux.Unlock()
	extToFormat[ext] = extFormat
	readerFuncs[extFormat] = readerFunc
	extFormat++
}
// Reader wraps the input reader.
// Reader reads from tabular files.
type Reader interface {
	// Names returns column names.
	Names() ([]string, error)
	// Types returns column types.
	Types() ([]string, error)
	// PreReadRow returns the rows that were read ahead (preRead rows).
	PreReadRow() [][]any
	// ReadRow reads the rest of the rows, one at a time.
	ReadRow(row []any) ([]any, error)
}
// ReadOpts represents options that determine the behavior of the reader.
type ReadOpts struct {
	// InDelimiter is the field delimiter.
	// default is ','
	InDelimiter string
	// InNULL is a string to replace with NULL.
	InNULL string
	// InJQuery is a jq expression.
	InJQuery string
	// InFormat is read format.
	// Supported formats include CSV/LTSV/JSON/YAML/TBLN/WIDTH/TEXT
	// (see extToFormat); GUESS selects by file extension.
	InFormat Format
	// realFormat is the format actually used after guessing;
	// NewReader dispatches on this field, not on InFormat.
	realFormat Format
	// InPreRead is number of rows to read ahead.
	// CSV/LTSV reads the specified number of rows to
	// determine the number of columns.
	InPreRead int
	// InSkip is number of rows to skip.
	// Skip reading specified number of lines.
	InSkip int
	// InLimitRead is limit read.
	InLimitRead bool
	// InHeader is true if there is a header.
	// It is used as a column name.
	InHeader bool
	// InNeedNULL is true, replace InNULL with NULL.
	InNeedNULL bool
	// IsTemporary is a flag whether to make temporary table.
	// default is true.
	IsTemporary bool
	// InRowNumber is row number.
	InRowNumber bool
}
// NewReadOpts returns a *ReadOpts with defaults applied, then modified
// by the given functional options.
// Defaults: format GUESS, one pre-read row, ',' delimiter, temporary table.
func NewReadOpts(options ...ReadOpt) *ReadOpts {
	// Fields left out of the literal take their zero value,
	// which matches the documented defaults (false, 0, "").
	opts := &ReadOpts{
		InFormat:    GUESS,
		InPreRead:   1,
		InDelimiter: ",",
		IsTemporary: true,
	}
	for _, opt := range options {
		opt(opts)
	}
	return opts
}
// ReadOpt is a functional option that mutates a *ReadOpts structure.
// Used when calling NewImporter / NewReadOpts.
type ReadOpt func(*ReadOpts)
// InFormat sets the read format.
func InFormat(f Format) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InFormat = f
	}
}

// InPreRead sets the number of lines to read ahead.
func InPreRead(p int) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InPreRead = p
	}
}

// InLimitRead sets the flag that limits reading to the pre-read rows.
func InLimitRead(p bool) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InLimitRead = p
	}
}

// InJQ sets a jq expression applied to the input.
func InJQ(p string) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InJQuery = p
	}
}

// InSkip sets the number of lines to skip.
func InSkip(s int) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InSkip = s
	}
}

// InDelimiter sets the field delimiter.
func InDelimiter(d string) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InDelimiter = d
	}
}

// InHeader sets whether the first row is a header.
func InHeader(h bool) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InHeader = h
	}
}

// InNeedNULL sets a flag as to whether it should be replaced with NULL.
func InNeedNULL(n bool) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InNeedNULL = n
	}
}

// InNULL sets the string to replace with NULL.
func InNULL(s string) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InNULL = s
	}
}

// IsTemporary sets a flag whether to make a temporary table.
func IsTemporary(t bool) ReadOpt {
	return func(ro *ReadOpts) {
		ro.IsTemporary = t
	}
}

// InRowNumber sets whether a row-number column is added.
func InRowNumber(t bool) ReadOpt {
	return func(ro *ReadOpts) {
		ro.InRowNumber = t
	}
}
// NewReader returns an Reader interface
// depending on the file to be imported.
// It dispatches on readOpts.realFormat (the format resolved after
// guessing), not on the user-facing InFormat field.
// Returns ErrNoReader when reader is nil, and ErrUnknownFormat when no
// ReaderFunc is registered for the format.
func NewReader(reader io.Reader, readOpts *ReadOpts) (Reader, error) {
	if reader == nil {
		return nil, ErrNoReader
	}
	readerFunc, ok := readerFuncs[readOpts.realFormat]
	if !ok {
		return nil, ErrUnknownFormat
	}
	return readerFunc(reader, readOpts)
}
// skipRead discards skipNum rows from the reader.
// Read errors are logged and stop the skipping early.
func skipRead(r Reader, skipNum int) {
	discard := make([]any, 1)
	for n := 0; n < skipNum; n++ {
		row, err := r.ReadRow(discard)
		if err != nil {
			log.Printf("ERROR: skip error %s", err)
			return
		}
		debug.Printf("Skip row:%s\n", row)
	}
}
package trdsql
import (
"encoding/hex"
"fmt"
"strconv"
"time"
"unicode/utf8"
)
// ValString converts a database value to its string representation.
// nil becomes "", []byte that is not valid UTF-8 is hex-encoded with a
// \x prefix, integers are rendered in base 10, times as RFC3339, and
// anything else falls back to fmt.Sprint.
func ValString(v any) string {
	switch val := v.(type) {
	case nil:
		return ""
	case string:
		return val
	case int:
		return strconv.Itoa(val)
	case int32:
		return strconv.FormatInt(int64(val), 10)
	case int64:
		return strconv.FormatInt(val, 10)
	case time.Time:
		return val.Format(time.RFC3339)
	case []byte:
		if utf8.Valid(val) {
			return string(val)
		}
		return `\x` + hex.EncodeToString(val)
	default:
		return fmt.Sprint(v)
	}
}
// replaceNULL returns nil when the string or []byte value equals
// nullString; all other values (including nil) pass through unchanged.
func replaceNULL(nullString string, v any) any {
	switch value := v.(type) {
	case string:
		if value == nullString {
			return nil
		}
	case []byte:
		if string(value) == nullString {
			return nil
		}
	}
	// nil and every non-string type fall through untouched.
	return v
}
// Package trdsql implements execute SQL queries on tabular data.
//
// trdsql imports tabular data into a database,
// executes SQL queries, and executes exports.
package trdsql
import (
"context"
"fmt"
"log"
)
// AppName is used for command names and is shown in generated
// SQL examples (see AnalyzeOpts.Command).
var AppName = "trdsql"
// multiT is a flag type for multiple queries.
type multiT bool

// multi is a package-level flag for multiple queries.
// NOTE(review): global mutable state — once enabled it cannot be
// disabled, and toggling is not synchronized for concurrent use.
var multi = multiT(false)

// EnableMultipleQueries enables multiple queries.
func EnableMultipleQueries() {
	multi = true
}
// TRDSQL represents DB definition and Importer/Exporter interface.
// It ties together where data comes from (Importer), where results go
// (Exporter), and which database executes the SQL.
type TRDSQL struct {
	// Importer is interface of processing to
	// import(create/insert) data. May be nil to skip importing.
	Importer Importer
	// Exporter is interface export to the process of
	// export(select) from the database. May be nil to skip exporting.
	Exporter Exporter
	// Driver is database driver name(sqlite3/sqlite/mysql/postgres).
	Driver string
	// Dsn is data source name.
	Dsn string
}
// NewTRDSQL returns a new TRDSQL structure wired to the given Importer
// and Exporter, using the default driver and an empty DSN.
func NewTRDSQL(im Importer, ex Exporter) *TRDSQL {
	trd := &TRDSQL{
		Importer: im,
		Exporter: ex,
	}
	trd.Driver = DefaultDriver
	return trd
}
// Format represents the import/export format.
type Format int

// Represents Format.
const (
	// import (guesses for import format).
	GUESS Format = iota
	// import/export
	// Format using go standard CSV library.
	CSV
	// import/export
	// Labeled Tab-separated Values.
	LTSV
	// import/export
	// Format using go standard JSON library.
	JSON
	// import/export
	// TBLN format(https://tbln.dev).
	TBLN
	// import
	// Format using guesswidth library.
	WIDTH
	// import
	TEXT
	// export
	// Output as it is.
	// Multiple characters can be selected as delimiter.
	RAW
	// export
	// MarkDown format.
	MD
	// export
	// ASCII Table format.
	AT
	// export
	// Vertical format.
	VF
	// export
	// JSON Lines format(http://jsonlines.org/).
	JSONL
	// import/export
	// YAML format.
	YAML
	// import
	// Tab-Separated Values format. Format using go standard CSV library.
	TSV
	// import
	// Pipe-Separated Values format. Format using go standard CSV library.
	PSV
)

// String returns the string representation of the Format.
// Unrecognized values yield "Unknown".
func (f Format) String() string {
	switch f {
	case GUESS:
		return "GUESS"
	case CSV:
		return "CSV"
	case LTSV:
		return "LTSV"
	case JSON:
		return "JSON"
	case TBLN:
		return "TBLN"
	case WIDTH:
		return "WIDTH"
	case TEXT:
		// Was missing: TEXT previously fell through to "Unknown",
		// e.g. in Analyze's "The file type is ..." message.
		return "TEXT"
	case RAW:
		return "RAW"
	case MD:
		return "MD"
	case AT:
		return "AT"
	case VF:
		return "VF"
	case JSONL:
		return "JSONL"
	case TSV:
		return "TSV"
	case PSV:
		return "PSV"
	case YAML:
		return "YAML"
	default:
		return "Unknown"
	}
}
// Exec runs the query with a background context.
// It is a convenience wrapper around ExecContext.
func (trd *TRDSQL) Exec(sql string) error {
	ctx := context.Background()
	return trd.ExecContext(ctx, sql)
}
// ExecContext connects to the database, runs import (if an Importer is
// set), then export (if an Exporter is set), inside a single transaction
// that is committed at the end.
// The Importer may rewrite the query (e.g. to reference imported tables).
// NOTE(review): on import/export error the transaction is not explicitly
// rolled back before Disconnect — presumably Disconnect/driver cleanup
// handles it; confirm.
func (trd *TRDSQL) ExecContext(ctx context.Context, sqlQuery string) error {
	db, err := Connect(trd.Driver, trd.Dsn)
	if err != nil {
		return fmt.Errorf("connect: %w", err)
	}
	defer func() {
		if deferErr := db.Disconnect(); deferErr != nil {
			log.Printf("disconnect: %s", deferErr)
		}
	}()
	db.Tx, err = db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin: %w", err)
	}
	if trd.Importer != nil {
		sqlQuery, err = trd.Importer.ImportContext(ctx, db, sqlQuery)
		if err != nil {
			return fmt.Errorf("import: %w", err)
		}
	}
	if trd.Exporter != nil {
		if err := trd.Exporter.ExportContext(ctx, db, sqlQuery); err != nil {
			return fmt.Errorf("export: %w", err)
		}
	}
	if err := db.Tx.Commit(); err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}
package trdsql
import (
"io"
"os"
)
// extToOutFormat is a map of file extensions to output formats.
// Keys are upper-case. Unlike the input map, JSONL maps to the JSONL
// writer, and the export-only formats (RAW/MD/AT/VF) are included.
var extToOutFormat = map[string]Format{
	"CSV":   CSV,
	"LTSV":  LTSV,
	"JSON":  JSON,
	"JSONL": JSONL,
	"TBLN":  TBLN,
	"RAW":   RAW,
	"MD":    MD,
	"AT":    AT,
	"VF":    VF,
	"YAML":  YAML,
	"YML":   YAML,
}
// Writer is an interface that wraps the Write method that writes from the database to a file.
// Writer is a group of methods called from Export, in the order
// PreWrite, WriteRow (per row), PostWrite.
type Writer interface {
	// PreWrite is called first to write.
	// The arguments are a list of column names and a list of type names.
	PreWrite(columns []string, types []string) error
	// WriteRow is row write.
	WriteRow(row []any, columns []string) error
	// PostWrite is called last in the write.
	PostWrite() error
}
// WriteOpts represents options that determine the behavior of the writer.
type WriteOpts struct {
	// OutStream is the output destination.
	OutStream io.Writer
	// ErrStream is the error output destination.
	ErrStream io.Writer
	// OutDelimiter is the output delimiter (Use only CSV and Raw).
	OutDelimiter string
	// OutQuote is the output quote character (Use only CSV).
	OutQuote string
	// OutNULL is the string output in place of NULL
	// when OutNeedNULL is true.
	OutNULL string
	// OutFormat is the writing format.
	OutFormat Format
	// OutAllQuotes is true if Enclose all fields (Use only CSV).
	OutAllQuotes bool
	// True to use \r\n as the line terminator (Use only CSV).
	OutUseCRLF bool
	// OutHeader is true if it outputs a header(Use only CSV and Raw).
	OutHeader bool
	// OutNoWrap is true, do not wrap long columns(Use only AT and MD).
	OutNoWrap bool
	// OutNeedNULL is true, replace NULL with OutNULL.
	OutNeedNULL bool
	// OutJSONToYAML is true, convert JSON to YAML(Use only YAML).
	OutJSONToYAML bool
}
// WriteOpt is a functional option that mutates a *WriteOpts.
// Pass these to NewWriter.
type WriteOpt func(*WriteOpts)
// OutFormat sets the output Format.
func OutFormat(f Format) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutFormat = f
	}
}

// OutDelimiter sets the output delimiter.
func OutDelimiter(d string) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutDelimiter = d
	}
}

// OutQuote sets the output quote character.
func OutQuote(q string) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutQuote = q
	}
}

// OutUseCRLF sets whether \r\n is used as the line terminator.
func OutUseCRLF(c bool) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutUseCRLF = c
	}
}

// OutAllQuotes sets whether all fields are quoted.
func OutAllQuotes(a bool) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutAllQuotes = a
	}
}

// OutHeader sets the flag to output a header.
func OutHeader(h bool) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutHeader = h
	}
}

// OutNoWrap sets the flag to not wrap long columns.
func OutNoWrap(w bool) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutNoWrap = w
	}
}

// OutNeedNULL sets the flag to replace NULL on output.
func OutNeedNULL(n bool) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutNeedNULL = n
	}
}

// OutNULL sets the string output in place of NULL.
func OutNULL(s string) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutNULL = s
	}
}

// OutStream sets the output destination.
func OutStream(w io.Writer) WriteOpt {
	return func(wo *WriteOpts) {
		wo.OutStream = w
	}
}

// ErrStream sets the error output destination.
func ErrStream(w io.Writer) WriteOpt {
	return func(wo *WriteOpts) {
		wo.ErrStream = w
	}
}
// NewWriter returns a Writer interface.
// The argument is an option of Functional Option Pattern.
//
// usage:
//
//	NewWriter(
//		trdsql.OutFormat(trdsql.CSV),
//		trdsql.OutHeader(true),
//		trdsql.OutDelimiter(";"),
//	)
//
// Any format without a dedicated writer (including CSV itself) falls
// back to the CSV writer.
func NewWriter(options ...WriteOpt) Writer {
	opts := &WriteOpts{
		OutFormat:    CSV,
		OutDelimiter: ",",
		OutQuote:     "\"",
		OutAllQuotes: false,
		OutUseCRLF:   false,
		OutHeader:    false,
		OutNeedNULL:  false,
		OutNULL:      "",
		OutStream:    os.Stdout,
		ErrStream:    os.Stderr,
	}
	for _, option := range options {
		option(opts)
	}
	switch opts.OutFormat {
	case LTSV:
		return NewLTSVWriter(opts)
	case JSON:
		return NewJSONWriter(opts)
	case JSONL:
		return NewJSONLWriter(opts)
	case YAML:
		return NewYAMLWriter(opts)
	case RAW:
		return NewRAWWriter(opts)
	case MD:
		// Markdown uses the table writer in markdown mode.
		return NewTWWriter(opts, true)
	case AT:
		return NewTWWriter(opts, false)
	case VF:
		return NewVFWriter(opts)
	case TBLN:
		return NewTBLNWriter(opts)
	default:
		// CSV and any unrecognized format.
		return NewCSVWriter(opts)
	}
}
// OutputFormat returns the output format for an (upper-case) file
// extension, or GUESS when the extension is not recognized.
func OutputFormat(ext string) Format {
	format, ok := extToOutFormat[ext]
	if !ok {
		return GUESS
	}
	return format
}