package ezdc
import (
"context"
"os"
"os/exec"
)
// ComposeFace builds the `docker compose` commands the Harness runs. Each
// method only constructs an *exec.Cmd; the caller decides whether to Run or
// Start it and may override its Stdout/Stderr. Harness.Run falls back to
// DefaultComposeCmd when no implementation has been set.
type ComposeFace interface {
	// Up builds the command that starts the project's services.
	Up(ctx context.Context) *exec.Cmd
	// Down builds the command that stops the project. DefaultComposeCmd also
	// removes volumes, orphaned containers and locally built images.
	Down(ctx context.Context) *exec.Cmd
	// Pull builds the command that pulls images for the named services.
	// DefaultComposeCmd returns nil when svc is empty.
	Pull(ctx context.Context, svc ...string) *exec.Cmd
	// Build builds the command that builds the project's images.
	Build(ctx context.Context) *exec.Cmd
}
// DefaultComposeCmd implements ComposeFace by invoking the docker CLI as
// `docker compose -p <project> -f <file> <subcommand> ...`.
type DefaultComposeCmd struct {
	project string // compose project name, passed via -p
	file    string // compose file path, passed via -f; "" means ./docker-compose.yml
}
// cmd assembles `docker compose -p <project> -f <file> <args...>` bound to ctx.
// An unset file falls back to ./docker-compose.yml. Output goes to the
// process's own stdout/stderr unless the caller replaces cmd.Stdout/Stderr.
func (c DefaultComposeCmd) cmd(ctx context.Context, args ...string) *exec.Cmd {
	composeFile := "./docker-compose.yml"
	if c.file != "" {
		composeFile = c.file
	}

	fullArgs := make([]string, 0, 5+len(args))
	fullArgs = append(fullArgs, "compose", "-p", c.project, "-f", composeFile)
	fullArgs = append(fullArgs, args...)

	command := exec.CommandContext(ctx, "docker", fullArgs...)
	command.Stdout, command.Stderr = os.Stdout, os.Stderr
	return command
}
// Down builds `docker compose down`. It removes volumes (-v), orphaned
// containers and locally built images (--rmi local), and uses a shutdown
// timeout of 0 seconds.
func (c DefaultComposeCmd) Down(ctx context.Context) *exec.Cmd {
	downArgs := []string{
		"down",
		"-v",
		"--remove-orphans",
		"--rmi", "local",
		"--timeout", "0",
	}
	return c.cmd(ctx, downArgs...)
}
// Pull builds `docker compose pull <svcs...>`. With no services it returns
// nil, so callers must check before running the command.
func (c DefaultComposeCmd) Pull(ctx context.Context, svcs ...string) *exec.Cmd {
	if len(svcs) > 0 {
		pullArgs := make([]string, 0, len(svcs)+1)
		pullArgs = append(pullArgs, "pull")
		pullArgs = append(pullArgs, svcs...)
		return c.cmd(ctx, pullArgs...)
	}
	return nil
}
// Build builds `docker compose build` for the configured project and file.
func (c DefaultComposeCmd) Build(ctx context.Context) *exec.Cmd {
	const subcommand = "build"
	return c.cmd(ctx, subcommand)
}
// Up builds `docker compose up` (no -d flag), for the configured project and
// file.
func (c DefaultComposeCmd) Up(ctx context.Context) *exec.Cmd {
	const subcommand = "up"
	return c.cmd(ctx, subcommand)
}
package ezdc
import (
"bytes"
"context"
"fmt"
"io"
"log"
"os"
"os/exec"
"os/signal"
"time"
)
// infoLog writes msg through the standard logger, prefixed with the
// "#### EZDC #### INFO" tag.
func infoLog(msg string) {
	// Println separates operands with a single space, which yields exactly
	// "#### EZDC #### INFO " + msg.
	log.Println("#### EZDC #### INFO", msg)
}
// errorLog writes msg through the standard logger, prefixed with the
// "#### EZDC #### ERROR" tag.
func errorLog(msg string) {
	// Println separates operands with a single space, which yields exactly
	// "#### EZDC #### ERROR " + msg.
	log.Println("#### EZDC #### ERROR", msg)
}
// Service configures options for a service defined in the docker compose file.
type Service struct {
	Name   string // service name as written in the compose file; used for pulling and in log messages
	Pull   bool   // pull the service's image before starting tests
	Waiter Waiter // optional, how to wait for service to be ready; nil means no wait
}
// FileLogWriter opens fileName for logging the docker compose output, e.g.
// Harness{Logs: FileLogWriter("compose.log")}.
//
// The file is created with mode 0644 if it does not exist, and any previous
// contents are discarded, so each run starts with an empty log. It panics if
// the file cannot be opened. The caller owns the returned file and should
// Close it when done.
func FileLogWriter(fileName string) *os.File {
	// O_TRUNC empties the file as part of the open itself. The previous
	// separate Truncate and Seek calls silently ignored their errors.
	logFile, err := os.OpenFile(fileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		panic(fmt.Errorf("failed to open log file %s: %w", fileName, err))
	}
	return logFile
}
// crasher ends the process with the given exit code. Run calls it after
// cleaning up when an interrupt signal is received, and uses exitCrasher when
// no crasher has been set (presumably so tests can substitute a non-exiting
// implementation — TODO confirm).
type crasher interface {
	Crash(exitCode int)
}
// exitCrasher is the crasher Harness.Run uses by default.
type exitCrasher struct{}

// Crash terminates the process immediately via os.Exit(exitCode); deferred
// functions are not run.
func (exitCrasher) Crash(exitCode int) { os.Exit(exitCode) }
// Harness brings up a docker compose project around a test run (see Run) and
// tears it down afterwards. The exported fields are configuration. Run sets
// cc and crasher to their defaults when they are nil.
type Harness struct {
	ProjectName   string                  // name for the compose project (passed as -p by DefaultComposeCmd)
	File          string                  // path to the docker compose file; DefaultComposeCmd uses ./docker-compose.yml when ""
	Services      []Service               // configuration for services
	Logs          io.Writer               // where to send the docker compose logs; nil means os.Stdout
	termSig       chan os.Signal          // receives os.Interrupt; created by Run
	cc            ComposeFace             // builds the compose commands; Run defaults it to DefaultComposeCmd
	cleanerUppers []func(context.Context) // registered via CleanupFunc, run before "down"
	crasher       crasher                 // exits the process after an interrupt; Run defaults it to exitCrasher
}
// Run is the entrypoint for running your testing.M.
//
//	func TestMain(m *testing.M) {
//		h := Harness{.....} // configure
//
//		exitCode, err := h.Run(context.Background(), m.Run)
//		if err != nil {
//			panic(err)
//		}
//		os.Exit(exitCode)
//	}
//
// Run proceeds in these steps:
//  1. Remove resources left over from a previous run.
//  2. Pull, build and start the compose services.
//  3. Wait for every Service that has a Waiter.
//  4. Call f and return its result as code.
//
// If start-up or waiting fails, Run returns code 1 and the error without
// calling f, and the services are not brought down. Once f has been called,
// the services are brought down when it returns or panics. A recovered panic
// is reported as err with code 1.
//
// While Run is executing, an os.Interrupt cancels ctx, runs the cleanup and
// then ends the process through the crasher with exit code 1.
func (h *Harness) Run(ctx context.Context, f func() int) (code int, err error) {
	ctx, cncl := context.WithCancel(ctx)
	defer cncl()

	if h.cc == nil {
		h.cc = DefaultComposeCmd{
			project: h.ProjectName,
			file:    h.File,
		}
	}
	if h.crasher == nil {
		h.crasher = exitCrasher{}
	}

	// done is closed when Run returns, so the signal goroutine below exits.
	// Otherwise it would linger, and react to a later interrupt, for the
	// rest of the process.
	done := make(chan struct{})
	defer close(done)

	// signal.Notify never blocks when delivering, so the channel needs a
	// buffer. Without one, an interrupt that arrives at the wrong moment is
	// dropped.
	h.termSig = make(chan os.Signal, 1)
	signal.Notify(h.termSig, os.Interrupt)
	// Deferred after close(done), so it runs first. Signals stop being
	// relayed to a channel that nobody reads anymore.
	defer signal.Stop(h.termSig)

	go func() {
		select {
		case <-h.termSig:
		case <-done:
			return
		}
		infoLog("got os Signal")
		cncl()
		h.cleanup(10 * time.Second)
		h.crasher.Crash(1)
	}()

	if err := h.startDcServices(ctx); err != nil {
		return 1, err
	}
	if err := h.waitForServices(ctx); err != nil {
		return 1, err
	}
	infoLog("services ready")

	defer func() {
		if r := recover(); r != nil {
			// A panicking test run is a failure. Don't report the zero exit code.
			code = 1
			err = fmt.Errorf("panic: %v", r)
		}
		h.cleanup(10 * time.Second)
	}()
	return f(), nil
}
// withLogs routes cmd's output to h.Logs, or to os.Stdout when h.Logs is nil.
// Stdout goes there directly. Stderr goes there and into a fresh buffer, so
// the caller can echo the stderr of a failed command. It returns cmd itself
// together with that buffer.
func (h Harness) withLogs(cmd *exec.Cmd) (*exec.Cmd, *bytes.Buffer) {
	out := h.Logs
	if out == nil {
		out = os.Stdout
	}
	stderrCopy := new(bytes.Buffer)
	cmd.Stdout = out
	cmd.Stderr = io.MultiWriter(out, stderrCopy)
	return cmd, stderrCopy
}
// startDcServices brings the compose project up from a clean slate:
//  1. "down" removes anything left from a previous run. Its result is
//     ignored.
//  2. "pull" runs for every Service with Pull set, if any.
//  3. "build" runs.
//  4. "up" is started in the background and not waited on, so the caller
//     can poll readiness with the configured Waiters.
//
// When a step fails, its captured stderr is also echoed to os.Stderr. The
// returned error names the step and wraps the underlying error.
func (h Harness) startDcServices(ctx context.Context) error {
	infoLog("cleaning up lingering resources")
	cmd, _ := h.withLogs(h.cc.Down(ctx))
	// Best effort: a failure here usually just means there was nothing to
	// remove.
	_ = cmd.Run()

	toPull := gmap(
		filter(h.Services, func(s Service) bool { return s.Pull }),
		func(s Service) string { return s.Name },
	)
	if len(toPull) > 0 {
		cmd, errBuf := h.withLogs(h.cc.Pull(ctx, toPull...))
		infoLog("pulling")
		if err := cmd.Run(); err != nil {
			_, _ = fmt.Fprintln(os.Stderr, errBuf.String())
			return fmt.Errorf("error pulling: %w", err)
		}
	}

	cmd, errBuf := h.withLogs(h.cc.Build(ctx))
	infoLog("building")
	if err := cmd.Run(); err != nil {
		_, _ = fmt.Fprintln(os.Stderr, errBuf.String())
		return fmt.Errorf("error building: %w", err)
	}

	cmd, errBuf = h.withLogs(h.cc.Up(ctx))
	infoLog("starting")
	// Start, not Run: "up" keeps running in the foreground while the
	// services are up.
	if err := cmd.Start(); err != nil {
		_, _ = fmt.Fprintln(os.Stderr, errBuf.String())
		return fmt.Errorf("error starting: %w", err)
	}
	return nil
}
// waitForServices blocks until every Service that has a Waiter reports ready.
// It checks them one at a time, in the order they appear in h.Services. The
// first Waiter error is returned, wrapped with the service name.
func (h Harness) waitForServices(ctx context.Context) error {
	toWaitFor := filter(h.Services, func(s Service) bool {
		return s.Waiter != nil
	})
	for _, svc := range toWaitFor {
		// infoLog goes through log.Println, which already ends the line, so
		// there is no trailing "\n" here.
		infoLog(fmt.Sprintf("waiting for '%s'...", svc.Name))
		if err := svc.Waiter.Wait(ctx); err != nil {
			return fmt.Errorf("waiting for %q: %w", svc.Name, err)
		}
	}
	return nil
}
// CleanupFunc registers a function to be run before stopping the docker
// compose services. Registered functions run in registration order and
// receive a context bounded by the cleanup timeout.
func (h *Harness) CleanupFunc(f func(context.Context)) {
	registered := h.cleanerUppers
	h.cleanerUppers = append(registered, f)
}
// cleanup runs every function registered with CleanupFunc, in registration
// order, and then stops the compose project with "down".
//
// The cleanup functions share a context bounded by timeout. "down" gets its
// own, fresh timeout. If it reused the first context, a cleanup function that
// used up the whole budget would leave "down" with an expired context. The
// command would then fail immediately and the containers would keep running.
// A failing "down" is logged, together with its stderr, rather than returned.
func (h Harness) cleanup(timeout time.Duration) {
	infoLog("cleaning up")

	func() {
		ctx, cncl := context.WithTimeout(context.Background(), timeout)
		defer cncl()
		for _, f := range h.cleanerUppers {
			f(ctx)
		}
	}()

	ctx, cncl := context.WithTimeout(context.Background(), timeout)
	defer cncl()
	cmd, errBuf := h.withLogs(h.cc.Down(ctx))
	if err := cmd.Run(); err != nil {
		_, _ = fmt.Fprintln(os.Stderr, errBuf.String())
		errorLog(fmt.Sprintf("failed to run 'down': %s", err))
	}
}
package ezdc
// gmap applies f to each element of t, in order, and returns the results.
// A nil or empty t yields a nil slice.
func gmap[T any, U any](t []T, f func(t T) U) []U {
	if len(t) == 0 {
		return nil
	}
	// The output length is known up front, so allocate exactly once instead
	// of growing via repeated append.
	res := make([]U, len(t))
	for i, item := range t {
		res[i] = f(item)
	}
	return res
}
// filter returns, in their original order, the elements of t for which f
// reports true. The result is nil when nothing matches, including for a nil
// or empty t.
func filter[T any](t []T, f func(t T) bool) []T {
	var kept []T
	for i := range t {
		if !f(t[i]) {
			continue
		}
		kept = append(kept, t[i])
	}
	return kept
}
// find reports whether f occurs in t, comparing elements with ==.
func find[T comparable](t []T, f T) bool {
	for i := range t {
		if t[i] != f {
			continue
		}
		return true
	}
	return false
}
package ezdc
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
)
// Waiter should implement Wait to return nil once the service is ready.
//
// Harness.Run calls Wait for each Service that has a Waiter, one after
// another, and aborts start-up with the first non-nil error. The context it
// passes is cancelled when the process receives an interrupt, so
// implementations should return once ctx is done.
type Waiter interface {
	Wait(context.Context) error
}
// TcpWaiter checks if a tcp connection can be established to localhost:Port.
type TcpWaiter struct {
	Interval time.Duration // pause between attempts; 0 means 500ms
	Timeout  time.Duration // timeout of each dial attempt; 0 means 2s
	Port     int           // TCP port on localhost to dial
}
// host returns the dial address for the waiter, e.g. "localhost:5432".
func (tw TcpWaiter) host() string {
	return net.JoinHostPort("localhost", fmt.Sprint(tw.Port))
}
// Wait dials localhost:Port over TCP until a connection succeeds. As soon as
// one does, it closes the connection and returns nil.
//
// Each attempt is bounded by Timeout (default 2s), and failed attempts are
// retried every Interval (default 500ms). Wait gives up only when ctx is
// done. Without a deadline or cancellation on ctx it retries indefinitely.
func (tw TcpWaiter) Wait(ctx context.Context) error {
	interval := tw.Interval
	if interval == 0 {
		interval = 500 * time.Millisecond
	}
	timeout := tw.Timeout
	if timeout == 0 {
		timeout = 2 * time.Second
	}
	d := net.Dialer{Timeout: timeout}
	addr := tw.host()

	for {
		c, err := d.DialContext(ctx, "tcp", addr)
		if err == nil {
			_ = c.Close()
			return nil
		}
		// Check ctx itself, not just the dial error. An attempt that hits the
		// dialer's own Timeout can also report an error matching
		// context.DeadlineExceeded, and that attempt should simply be
		// retried. A cancelled ctx, on the other hand, must stop the loop.
		if ctxErr := ctx.Err(); ctxErr != nil {
			// Keep the more descriptive dial error when it already carries
			// ctx's error. It still satisfies errors.Is(err, ctxErr).
			if errors.Is(err, ctxErr) {
				return err
			}
			return ctxErr
		}
		infoLog(fmt.Sprintf(" failed to connect: '%s'", err))

		// Pause for interval, but wake up early if ctx is done.
		timer := time.NewTimer(interval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
	}
}
// HttpWaiter ensures a healthy status code is received from a http endpoint
// served at http://localhost:<Port><Path>, polled with GET requests.
type HttpWaiter struct {
	Interval       time.Duration // pause between attempts; 0 means 500ms
	RequestTimeout time.Duration // timeout of each request; 0 means 2s
	Port           int           // port on localhost
	Path           string        // request path, e.g. "/health"
	ReadyStatus    []int         // status codes treated as ready; empty means 200, 201, 202, 204
}
// url returns the address polled by Wait: http://localhost:<Port><Path>.
// A Path without a leading "/" gets one prepended, so "health" and "/health"
// are equivalent. An empty Path targets "/".
func (hw HttpWaiter) url() string {
	path := hw.Path
	if !strings.HasPrefix(path, "/") {
		// Prepend the separator. Appending "/"+path to path would duplicate
		// the segment ("health" -> "health/health").
		path = "/" + path
	}
	return fmt.Sprintf("http://localhost:%d%s", hw.Port, path)
}
// Wait polls http://localhost:<Port><Path> with GET requests. It returns nil
// once a response status is one of ReadyStatus (default 200, 201, 202, 204).
//
// Each request is bounded by RequestTimeout (default 2s), and attempts are
// spaced by Interval (default 500ms). A request that fails or times out is
// retried. Wait gives up in two cases:
//   - ctx is done;
//   - the URL cannot be turned into a request at all.
//
// Without a deadline or cancellation on ctx it retries indefinitely.
func (hw HttpWaiter) Wait(ctx context.Context) error {
	readyStatus := hw.ReadyStatus
	if len(readyStatus) == 0 {
		readyStatus = []int{200, 201, 202, 204}
	}
	interval := hw.Interval
	if interval == 0 {
		interval = 500 * time.Millisecond
	}
	requestTimeout := hw.RequestTimeout
	if requestTimeout == 0 {
		requestTimeout = 2 * time.Second
	}

	// Build the request once. A malformed URL (e.g. a negative Port) will
	// never become valid, so report it rather than retrying.
	baseReq, err := http.NewRequest(http.MethodGet, hw.url(), nil)
	if err != nil {
		return fmt.Errorf("building readiness request: %w", err)
	}

	for i := 0; ; i++ {
		if i > 0 {
			// Pause for interval, but wake up early if ctx is done.
			timer := time.NewTimer(interval)
			select {
			case <-ctx.Done():
				timer.Stop()
				return ctx.Err()
			case <-timer.C:
			}
		}

		status, err := func() (int, error) {
			reqCtx, cncl := context.WithTimeout(ctx, requestTimeout)
			defer cncl()
			res, err := http.DefaultClient.Do(baseReq.WithContext(reqCtx))
			if err != nil {
				return 0, err
			}
			// Only the status code is needed. Close the body so the
			// connection is released.
			_ = res.Body.Close()
			return res.StatusCode, nil
		}()

		if err == nil {
			if find(readyStatus, status) {
				return nil
			}
			infoLog(fmt.Sprintf("failed to connect: status='%d'", status))
			continue
		}
		// A per-request timeout also yields an error matching
		// context.DeadlineExceeded, and that request should just be retried.
		// Stop only when the caller's ctx is done.
		if ctxErr := ctx.Err(); ctxErr != nil {
			// Keep the more descriptive transport error when it already
			// carries ctx's error. It still satisfies errors.Is(err, ctxErr).
			if errors.Is(err, ctxErr) {
				return err
			}
			return ctxErr
		}
		infoLog(fmt.Sprintf("failed to connect: err='%s'", err))
	}
}