package ezdc import ( "context" "os" "os/exec" ) type ComposeFace interface { Up(ctx context.Context) *exec.Cmd Down(ctx context.Context) *exec.Cmd Pull(ctx context.Context, svc ...string) *exec.Cmd Build(ctx context.Context) *exec.Cmd } type DefaultComposeCmd struct { project string file string } func (c DefaultComposeCmd) cmd(ctx context.Context, args ...string) *exec.Cmd { file := c.file if file == "" { file = "./docker-compose.yml" } cmd := exec.CommandContext(ctx, "docker", append([]string{"compose", "-p", c.project, "-f", file}, args...)..., ) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr return cmd } func (c DefaultComposeCmd) Down(ctx context.Context) *exec.Cmd { return c.cmd(ctx, "down", "-v", "--remove-orphans", "--rmi", "local", "--timeout", "0") } func (c DefaultComposeCmd) Pull(ctx context.Context, svcs ...string) *exec.Cmd { if len(svcs) == 0 { return nil } return c.cmd(ctx, append([]string{"pull"}, svcs...)...) } func (c DefaultComposeCmd) Build(ctx context.Context) *exec.Cmd { return c.cmd(ctx, "build") } func (c DefaultComposeCmd) Up(ctx context.Context) *exec.Cmd { return c.cmd(ctx, "up") }
package ezdc import ( "bytes" "context" "fmt" "io" "log" "os" "os/exec" "os/signal" "time" ) func infoLog(msg string) { log.Println("#### EZDC #### INFO " + msg) } func errorLog(msg string) { log.Println("#### EZDC #### ERROR " + msg) } // Service configures options for a service defined in the docker compose file type Service struct { Name string Pull bool // pull before starting tests Waiter Waiter // optional, how to wait for service to be ready } // FileLogWriter utility to open a file for logging the docker compose output func FileLogWriter(fileName string) *os.File { logFile, err := os.OpenFile(fileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { panic(fmt.Errorf("failed to open log file %s: %w", fileName, err)) } _ = logFile.Truncate(0) _, _ = logFile.Seek(0, 0) return logFile } type crasher interface { Crash(exitCode int) } type exitCrasher struct{} func (e exitCrasher) Crash(exitCode int) { os.Exit(exitCode) } type Harness struct { ProjectName string // name for the compose project File string // path to the docker compose file Services []Service // configuration for services Logs io.Writer // where to send the docker compose logs termSig chan os.Signal cc ComposeFace cleanerUppers []func(context.Context) crasher crasher } // Run is the entrypoint for running your testing.M. // // func TestMain(m *testing.M) { // h := Harness{.....} // configure // // exitCode, err := h.Run(context.Background(), m.Run) // if err != nil { // panic(err) // } // os.Exit(exitCode) // } func (h *Harness) Run(ctx context.Context, f func() int) (code int, err error) { ctx, cncl := context.WithCancel(ctx) defer cncl() h.termSig = make(chan os.Signal) if h.cc == nil { h.cc = DefaultComposeCmd{ project: h.ProjectName, file: h.File, } } if h.crasher == nil { h.crasher = exitCrasher{} } go func() { <-h.termSig infoLog("got os Signal") cncl() h.cleanup(10 * time.Second) h.crasher.Crash(1) }() signal.Notify(h.termSig, os.Interrupt) if err := h.startDcServices(ctx); err != nil { return 1, err } if err := h.waitForServices(ctx); err != nil { return 1, err } infoLog("services ready") defer func() { if panicErr := recover(); panicErr != nil { err = fmt.Errorf("panic: %s", panicErr) } h.cleanup(10 * time.Second) }() return f(), nil } func (h Harness) withLogs(cmd *exec.Cmd) (*exec.Cmd, *bytes.Buffer) { if h.Logs == nil { h.Logs = os.Stdout } buf := &bytes.Buffer{} cmd.Stdout = h.Logs cmd.Stderr = io.MultiWriter(h.Logs, buf) return cmd, buf } func (h Harness) startDcServices(ctx context.Context) error { infoLog("cleaning up lingering resources") cmd, _ := h.withLogs(h.cc.Down(ctx)) _ = cmd.Run() toPull := gmap( filter(h.Services, func(s Service) bool { return s.Pull }, ), func(s Service) string { return s.Name }) if len(toPull) > 0 { cmd, errBuf := h.withLogs(h.cc.Pull(ctx, toPull...)) infoLog("pulling") if err := cmd.Run(); err != nil { _, _ = fmt.Fprintln(os.Stderr, errBuf.String()) return fmt.Errorf("error pulling: %w", err) } } cmd, errBuf := h.withLogs(h.cc.Build(ctx)) infoLog("building") if err := cmd.Run(); err != nil { _, _ = fmt.Fprintln(os.Stderr, errBuf.String()) return fmt.Errorf("error building: %w", err) } cmd, errBuf = h.withLogs(h.cc.Up(ctx)) infoLog("starting") if err := cmd.Start(); err != nil { _, _ = fmt.Fprintln(os.Stderr, errBuf.String()) return err } return nil } func (h Harness) waitForServices(ctx context.Context) error { toWaitFor := filter(h.Services, func(s Service) bool { return s.Waiter != nil }) for _, svc := range toWaitFor { infoLog(fmt.Sprintf("waiting for '%s'...\n", svc.Name)) if err := svc.Waiter.Wait(ctx); err != nil { return err } } return nil } // CleanupFunc registers a function to be run before stopping the docker compose services func (h *Harness) CleanupFunc(f func(context.Context)) { h.cleanerUppers = append(h.cleanerUppers, f) } func (h Harness) cleanup(timeout time.Duration) { infoLog("cleaning up") ctx, cncl := context.WithTimeout(context.Background(), timeout) defer cncl() for _, f := range h.cleanerUppers { f(ctx) } cmd, errBuf := h.withLogs(h.cc.Down(ctx)) if err := cmd.Run(); err != nil { _, _ = fmt.Fprintln(os.Stderr, errBuf.String()) errorLog(fmt.Sprintf("failed to run 'down': %s", err)) } }
package ezdc func gmap[T any, U any](t []T, f func(t T) U) (res []U) { for _, item := range t { out := f(item) res = append(res, out) } return } func filter[T any](t []T, f func(t T) bool) (res []T) { for _, item := range t { if f(item) { res = append(res, item) } } return } func find[T comparable](t []T, f T) bool { for _, item := range t { if item == f { return true } } return false }
package ezdc import ( "context" "errors" "fmt" "net" "net/http" "strings" "time" ) // Waiter should implement Wait to return nil once the service is ready type Waiter interface { Wait(context.Context) error } // TcpWaiter checks if a tcp connection can be established type TcpWaiter struct { Interval time.Duration Timeout time.Duration Port int } func (tw TcpWaiter) host() string { return fmt.Sprintf("localhost:%d", tw.Port) } func (tw TcpWaiter) Wait(ctx context.Context) error { interval := tw.Interval if interval == 0 { interval = 500 * time.Millisecond } timeout := tw.Timeout if timeout == 0 { timeout = 2 * time.Second } for i := 0; ; i++ { d := net.Dialer{ Timeout: timeout, } var c net.Conn c, err := d.DialContext(ctx, "tcp", tw.host()) if err == nil { _ = c.Close() return nil } if errors.Is(err, context.DeadlineExceeded) { return err } infoLog(fmt.Sprintf(" failed to connect: '%s'", err)) time.Sleep(interval) } } // HttpWaiter ensures a healthy status code is received from a http endpoint type HttpWaiter struct { Interval time.Duration RequestTimeout time.Duration Port int Path string ReadyStatus []int } func (hw HttpWaiter) url() string { path := hw.Path if !strings.HasPrefix(path, "/") { path += "/" + path } return fmt.Sprintf("http://localhost:%d%s", hw.Port, path) } func (hw HttpWaiter) Wait(ctx context.Context) error { readyStatus := hw.ReadyStatus if len(readyStatus) == 0 { readyStatus = []int{200, 201, 202, 204} } interval := hw.Interval if interval == 0 { interval = 500 * time.Millisecond } requestTimeout := hw.RequestTimeout if requestTimeout == 0 { requestTimeout = 2 * time.Second } u := hw.url() for i := 0; ; i++ { if i > 0 { time.Sleep(interval) } var ( res *http.Response err error ) func() { ctx, cncl := context.WithTimeout(ctx, requestTimeout) defer cncl() req, _ := http.NewRequest("GET", u, nil) req = req.WithContext(ctx) res, err = http.DefaultClient.Do(req) }() if err == nil { if find(readyStatus, res.StatusCode) { return nil } infoLog(fmt.Sprintf("failed to connect: status='%d'", res.StatusCode)) continue } if errors.Is(err, context.DeadlineExceeded) { return err } infoLog(fmt.Sprintf("failed to connect: err='%s'", err)) } }