package wazero_grpc_server
// Error codes exchanged with the guest through the errCode slot of the meta
// page. The numeric values are written into guest memory, so their order must
// not change (presumably the guest mirrors them — confirm against the guest SDK).
const (
	errCodeEmpty          uint32 = iota // no error; written before each guest call and on a successful Recv
	errCodeDone                         // written by the Recv host function once the client stream has ended
	errCodeUnknown                      // the remaining codes map to the Err* values below
	errCodeInvalid
	errCodeUnrecognized
	errCodeNotImplemented
	errCodeMalformed
	errCodeUnexpected
	errCodeMarshal
)

// Canonical errors for each guest-reportable error code.
var (
	ErrUnknown        = &Error{code: errCodeUnknown, msg: "Unknown"}
	ErrInvalid        = &Error{code: errCodeInvalid, msg: "Invalid"}
	ErrUnrecognized   = &Error{code: errCodeUnrecognized, msg: "Unrecognized"}
	ErrNotImplemented = &Error{code: errCodeNotImplemented, msg: "Not Implemented"}
	ErrMalformed      = &Error{code: errCodeMalformed, msg: "Malformed"}
	ErrUnexpected     = &Error{code: errCodeUnexpected, msg: "Unexpected"}
	ErrMarshal        = &Error{code: errCodeMarshal, msg: "Marshal"}
)

// errorsByCode indexes the canonical errors by their code. errCodeEmpty and
// errCodeDone are deliberately absent: they do not represent failures.
var errorsByCode = func() map[uint32]error {
	known := []*Error{
		ErrUnknown,
		ErrInvalid,
		ErrUnrecognized,
		ErrNotImplemented,
		ErrMalformed,
		ErrUnexpected,
		ErrMarshal,
	}
	byCode := make(map[uint32]error, len(known))
	for _, e := range known {
		byCode[e.code] = e
	}
	return byCode
}()
// Error is a failure reported by the guest module, identified by a numeric
// code that is exchanged with the guest through its errCode slot.
// The exported Err* values are the canonical instances for each code.
type Error struct {
	code uint32
	msg  string
}

// Error returns the human-readable message.
func (e Error) Error() string {
	return e.msg
}

// Is reports whether target is an Error (value or non-nil pointer) carrying
// the same code. This lets errors.Is(err, ErrInvalid) match annotated copies
// of a canonical error, not only the canonical pointer itself.
func (e Error) Is(target error) bool {
	switch t := target.(type) {
	case *Error:
		return t != nil && t.code == e.code
	case Error:
		return t.code == e.code
	}
	return false
}
// Borrowed heavily from mwitkow/grpc-proxy (Apache 2.0)
// See https://github.com/mwitkow/grpc-proxy/blob/master/proxy/handler.go
package wazero_grpc_server
import (
"context"
"io"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/pantopic/wazero-pool"
)
// grpcHandler proxies gRPC streams to module instances borrowed from pool.
// meta is shared by every pooled instance.
type grpcHandler struct {
	pool wazeropool.Instance
	meta *meta
}

// handler returns a grpc stream handler that borrows one module instance per
// RPC, wraps it with the adapter built by f, and pumps messages in both
// directions until the guest side finishes or either side fails.
func (h *grpcHandler) handler(f handlerFactory) func(srv any, serverStream grpc.ServerStream) error {
	return func(srv any, serverStream grpc.ServerStream) error {
		fullMethodName, ok := grpc.MethodFromServerStream(serverStream)
		if !ok {
			return status.Errorf(codes.Internal, "lowLevelServerStream not exists in context")
		}
		// Defers run LIFO. Registering Put before cancel means the RPC context
		// is cancelled first, and only then is the instance returned to the
		// pool. NOTE(review): this signals, but does not wait for, a guest
		// call that may still be running on mod.
		mod := h.pool.Get()
		defer h.pool.Put(mod)
		ctx, cancel := context.WithCancel(serverStream.Context())
		defer cancel()

		clientStream := f(ctx, mod, h.meta, fullMethodName)
		s2cErrChan := h.forwardServerToWazero(serverStream, clientStream)
		c2sErrChan := h.forwardWazeroToServer(clientStream, serverStream)
		// Each forwarder reports exactly once.
		for range 2 {
			select {
			case s2cErr := <-s2cErrChan:
				if s2cErr != io.EOF {
					cancel()
					return status.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr)
				}
				// The client finished sending; let the guest side observe end of input.
				clientStream.CloseSend()
			case c2sErr := <-c2sErrChan:
				serverStream.SetTrailer(clientStream.Trailer())
				if c2sErr != io.EOF {
					return c2sErr
				}
				return nil
			}
		}
		return status.Errorf(codes.Internal, "gRPC proxying should never reach this stage.")
	}
}

// forwardWazeroToServer copies guest responses from src to dst.
// Headers from src are forwarded before the first message. The returned
// channel receives exactly one error; io.EOF means src ended normally.
func (h *grpcHandler) forwardWazeroToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error {
	ret := make(chan error, 1)
	go func() {
		for first := true; ; first = false {
			// grpc-go documents that a message must not be modified after it
			// is handed to SendMsg (stats handlers may read it lazily), so
			// every message gets a fresh frame.
			f := &emptypb.Empty{}
			if err := src.RecvMsg(f); err != nil {
				ret <- err
				return
			}
			if first {
				md, err := src.Header()
				if err != nil {
					ret <- err
					return
				}
				if err := dst.SendHeader(md); err != nil {
					ret <- err
					return
				}
			}
			if err := dst.SendMsg(f); err != nil {
				ret <- err
				return
			}
		}
	}()
	return ret
}

// forwardServerToWazero copies client requests from src to dst.
// The returned channel receives exactly one error; io.EOF means the client
// finished sending.
func (h *grpcHandler) forwardServerToWazero(src grpc.ServerStream, dst grpc.ClientStream) chan error {
	ret := make(chan error, 1)
	go func() {
		// emptypb.Empty carries the raw payload as unknown fields. Reuse is
		// safe here because dst (a local adapter) marshals synchronously in SendMsg.
		f := &emptypb.Empty{}
		for {
			if err := src.RecvMsg(f); err != nil {
				ret <- err
				return
			}
			if err := dst.SendMsg(f); err != nil {
				ret <- err
				return
			}
		}
	}()
	return ret
}
package wazero_grpc_server
import (
"context"
"io"
"log"
"github.com/tetratelabs/wazero/api"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
)
// newHandlerClientStream returns a grpc.ClientStream that runs a client-streaming
// call on mod. The first request starts the guest's exported grpcCall. Later
// requests are delivered through the channel stored under DefaultCtxKeyNext,
// which the Recv host function reads.
func newHandlerClientStream(ctx context.Context, mod api.Module, meta *meta, method string) grpc.ClientStream {
	next := make(chan []byte)
	// The host functions registered by hostModule.Register look these values
	// up in the context passed to grpcCall.
	ctx = context.WithValue(ctx, DefaultCtxKeyMeta, meta)
	ctx = context.WithValue(ctx, DefaultCtxKeyNext, next)
	return &handlerClientStream{ctx: ctx, mod: mod, meta: meta, method: method, next: next, done: make(chan bool)}
}

// handlerClientStream adapts a client-streaming call on a guest module
// instance to grpc.ClientStream.
type handlerClientStream struct {
	ctx    context.Context
	mod    api.Module
	meta   *meta
	method string
	// next carries the second and later request messages to the guest.
	next chan []byte
	// done is signalled when grpcCall returns. RecvMsg closes it after
	// consuming the response; CloseSend closes it if no request ever arrived.
	done chan bool
	// init records whether the first request has started grpcCall.
	init bool
	// callErr is written by the grpcCall goroutine before it signals done.
	callErr error
}

func (cs *handlerClientStream) Header() (md metadata.MD, err error) {
	return
}

func (cs *handlerClientStream) Trailer() (md metadata.MD) {
	return
}

// CloseSend closes the channel feeding the guest's Recv so the guest sees end
// of stream. If no request was ever sent, grpcCall never started; closing done
// lets RecvMsg report io.EOF instead of waiting until the RPC is cancelled.
func (cs *handlerClientStream) CloseSend() error {
	close(cs.next)
	if !cs.init {
		close(cs.done)
	}
	return nil
}

func (cs *handlerClientStream) Context() context.Context {
	return cs.ctx
}

// SendMsg forwards one request to the guest. The first request is written into
// the guest's message buffer and starts grpcCall in a goroutine. Later requests
// are handed to the guest's Recv host function via next.
func (cs *handlerClientStream) SendMsg(m any) error {
	b, err := proto.Marshal(m.(proto.Message))
	if err != nil {
		return err
	}
	// Reject requests the guest buffer cannot hold rather than overflowing it.
	if limit := readUint32(cs.mod, cs.meta.ptrMsgMax); uint64(len(b)) > uint64(limit) {
		return &Error{code: errCodeMalformed, msg: ErrMalformed.msg + ": request exceeds guest message buffer"}
	}
	if cs.init {
		select {
		case cs.next <- b:
			return nil
		case <-cs.ctx.Done():
			return cs.ctx.Err()
		}
	}
	cs.init = true
	setMethod(cs.mod, cs.meta, []byte(cs.method))
	setMsg(cs.mod, cs.meta, b)
	setErrCode(cs.mod, cs.meta, errCodeEmpty)
	go func() {
		if _, err := cs.mod.ExportedFunction("grpcCall").Call(cs.ctx); err != nil {
			log.Println(err)
			cs.callErr = err
		}
		// Do not block forever if RecvMsg already gave up on a cancelled RPC.
		select {
		case cs.done <- true:
		case <-cs.ctx.Done():
		}
	}()
	return nil
}

// RecvMsg waits for grpcCall to finish and unmarshals the guest's response
// into m. It returns:
//   - io.EOF once the response has been consumed;
//   - the grpcCall error if the call failed;
//   - an annotated Error if the guest set an error code;
//   - the context error if the RPC ends first.
func (cs *handlerClientStream) RecvMsg(m any) error {
	select {
	case _, ok := <-cs.done:
		if !ok {
			return io.EOF
		}
	case <-cs.ctx.Done():
		return cs.ctx.Err()
	}
	// Read after receiving from done, which orders it after the goroutine's write.
	if cs.callErr != nil {
		return cs.callErr
	}
	if ferr := getError(cs.mod, cs.meta); ferr != nil {
		// Annotate a copy: getError returns a package-level value shared by all requests.
		base := ferr.(*Error)
		return &Error{code: base.code, msg: base.msg + `: ` + string(msg(cs.mod, cs.meta))}
	}
	err := proto.Unmarshal(msg(cs.mod, cs.meta), m.(proto.Message))
	close(cs.done)
	return err
}
package wazero_grpc_server
import (
"context"
"io"
"github.com/tetratelabs/wazero/api"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
)
// newHandlerServerStream returns a grpc.ClientStream that feeds one request to
// the guest's exported grpcCall and reports the response the guest leaves in
// its message buffer.
// NOTE(review): this is currently identical to handlerUnary and yields a single
// response; streaming more than one appears unimplemented (the Send host
// function is a stub).
func newHandlerServerStream(ctx context.Context, mod api.Module, meta *meta, method string) grpc.ClientStream {
	return &handlerServerStream{ctx: ctx, mod: mod, meta: meta, method: method, ready: make(chan bool)}
}

// handlerServerStream adapts a server-streaming call on a guest module
// instance to grpc.ClientStream.
type handlerServerStream struct {
	ctx    context.Context
	mod    api.Module
	meta   *meta
	method string
	// ready is signalled by SendMsg once grpcCall has returned successfully,
	// and closed by RecvMsg after the response was consumed.
	ready chan bool
	// sent is set by the first SendMsg; only the SendMsg caller touches it.
	sent bool
}

func (cs *handlerServerStream) Header() (md metadata.MD, err error) {
	return
}

func (cs *handlerServerStream) Trailer() (md metadata.MD) {
	return
}

func (cs *handlerServerStream) CloseSend() (err error) {
	return
}

func (cs *handlerServerStream) Context() context.Context {
	return cs.ctx
}

// SendMsg writes the single request into the guest's message buffer and runs
// grpcCall synchronously. A grpcCall failure is returned directly.
func (cs *handlerServerStream) SendMsg(m any) error {
	if cs.sent {
		// A second request would otherwise send on ready after RecvMsg closed
		// it, panicking the forwarding goroutine.
		return &Error{code: errCodeUnexpected, msg: ErrUnexpected.msg + ": more than one request message"}
	}
	cs.sent = true
	b, err := proto.Marshal(m.(proto.Message))
	if err != nil {
		return err
	}
	// Reject requests the guest buffer cannot hold rather than overflowing it.
	if limit := readUint32(cs.mod, cs.meta.ptrMsgMax); uint64(len(b)) > uint64(limit) {
		return &Error{code: errCodeMalformed, msg: ErrMalformed.msg + ": request exceeds guest message buffer"}
	}
	setMethod(cs.mod, cs.meta, []byte(cs.method))
	setMsg(cs.mod, cs.meta, b)
	// Clear any code left behind by a previous call on this pooled instance.
	setErrCode(cs.mod, cs.meta, errCodeEmpty)
	if _, err := cs.mod.ExportedFunction("grpcCall").Call(cs.ctx); err != nil {
		return err
	}
	select {
	case cs.ready <- true:
		return nil
	case <-cs.ctx.Done():
		return cs.ctx.Err()
	}
}

// RecvMsg waits for grpcCall to finish and unmarshals the guest's response
// into m. It returns:
//   - io.EOF once the response has been consumed;
//   - an annotated Error if the guest set an error code;
//   - the context error if the RPC ends first.
func (cs *handlerServerStream) RecvMsg(m any) error {
	select {
	case _, ok := <-cs.ready:
		if !ok {
			return io.EOF
		}
	case <-cs.ctx.Done():
		return cs.ctx.Err()
	}
	if ferr := getError(cs.mod, cs.meta); ferr != nil {
		// Annotate a copy: getError returns a package-level value shared by all requests.
		base := ferr.(*Error)
		return &Error{code: base.code, msg: base.msg + `: ` + string(msg(cs.mod, cs.meta))}
	}
	err := proto.Unmarshal(msg(cs.mod, cs.meta), m.(proto.Message))
	close(cs.ready)
	return err
}
package wazero_grpc_server
import (
"context"
"io"
"github.com/tetratelabs/wazero/api"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
)
// newHandlerUnary returns a grpc.ClientStream that feeds a single request to
// the guest's exported grpcCall and reports the single response the guest
// leaves in its message buffer.
func newHandlerUnary(ctx context.Context, mod api.Module, meta *meta, method string) grpc.ClientStream {
	return &handlerUnary{ctx: ctx, mod: mod, meta: meta, method: method, ready: make(chan bool)}
}

// handlerUnary adapts one unary call on a guest module instance to grpc.ClientStream.
type handlerUnary struct {
	ctx    context.Context
	mod    api.Module
	meta   *meta
	method string
	// ready is signalled by SendMsg once grpcCall has returned successfully,
	// and closed by RecvMsg after the response was consumed.
	ready chan bool
	// sent is set by the first SendMsg; only the SendMsg caller touches it.
	sent bool
}

func (cs *handlerUnary) Header() (md metadata.MD, err error) {
	return
}

func (cs *handlerUnary) Trailer() (md metadata.MD) {
	return
}

func (cs *handlerUnary) CloseSend() (err error) {
	return
}

func (cs *handlerUnary) Context() context.Context {
	return cs.ctx
}

// SendMsg writes the single request into the guest's message buffer and runs
// grpcCall synchronously. A grpcCall failure is returned directly.
func (cs *handlerUnary) SendMsg(m any) error {
	if cs.sent {
		// A second request would otherwise send on ready after RecvMsg closed
		// it, panicking the forwarding goroutine.
		return &Error{code: errCodeUnexpected, msg: ErrUnexpected.msg + ": more than one request message"}
	}
	cs.sent = true
	b, err := proto.Marshal(m.(proto.Message))
	if err != nil {
		return err
	}
	// Reject requests the guest buffer cannot hold rather than overflowing it.
	if limit := readUint32(cs.mod, cs.meta.ptrMsgMax); uint64(len(b)) > uint64(limit) {
		return &Error{code: errCodeMalformed, msg: ErrMalformed.msg + ": request exceeds guest message buffer"}
	}
	setMethod(cs.mod, cs.meta, []byte(cs.method))
	setMsg(cs.mod, cs.meta, b)
	// Clear any code left behind by a previous call on this pooled instance.
	setErrCode(cs.mod, cs.meta, errCodeEmpty)
	if _, err := cs.mod.ExportedFunction("grpcCall").Call(cs.ctx); err != nil {
		return err
	}
	select {
	case cs.ready <- true:
		return nil
	case <-cs.ctx.Done():
		return cs.ctx.Err()
	}
}

// RecvMsg waits for grpcCall to finish and unmarshals the guest's response
// into m. It returns:
//   - io.EOF once the response has been consumed;
//   - an annotated Error if the guest set an error code;
//   - the context error if the RPC ends first.
func (cs *handlerUnary) RecvMsg(m any) error {
	select {
	case _, ok := <-cs.ready:
		if !ok {
			return io.EOF
		}
	case <-cs.ctx.Done():
		return cs.ctx.Err()
	}
	if ferr := getError(cs.mod, cs.meta); ferr != nil {
		// Annotate a copy: getError returns a package-level value shared by all requests.
		base := ferr.(*Error)
		return &Error{code: base.code, msg: base.msg + `: ` + string(msg(cs.mod, cs.meta))}
	}
	err := proto.Unmarshal(msg(cs.mod, cs.meta), m.(proto.Message))
	close(cs.ready)
	return err
}
package wazero_grpc_server
import (
	"context"
	"fmt"
	"log"
	"strings"
	"sync"

	"github.com/pantopic/wazero-pool"
	"github.com/tetratelabs/wazero"
	"github.com/tetratelabs/wazero/api"
	"google.golang.org/grpc"
)
// DefaultCtxKeyMeta is the default context key under which a module
// instance's decoded meta page is stored.
// DefaultCtxKeyNext is the context key for the channel that feeds
// client-stream messages to the guest's Recv host function.
var (
	DefaultCtxKeyMeta = "wazero_grpc_server_meta_key"
	DefaultCtxKeyNext = "wazero_grpc_next"
)

// meta holds the guest addresses decoded from the meta page that the guest's
// exported `grpc` function returns.
//   - The *Max and *Len fields are addresses of uint32 values: buffer capacity
//     and current length.
//   - ptrMethod and ptrMsg are the addresses of the method and message buffers.
//   - ptrErrCode is the address of the error-code slot.
type meta struct {
	ptrMethodMax uint32 // meta page offset 0
	ptrMethodLen uint32 // meta page offset 4
	ptrMsgMax    uint32 // meta page offset 12
	ptrMsgLen    uint32 // meta page offset 16
	ptrErrCode   uint32 // meta page offset 24
	ptrMethod    uint32 // meta page offset 8
	ptrMsg       uint32 // meta page offset 20
}

// hostModule implements the "grpc" wazero host module and the glue that
// exposes guest-defined services on a *grpc.Server.
type hostModule struct {
	sync.RWMutex
	module       api.Module // the instantiated host module, set by Register
	ctxKeyMeta   string     // context key for the meta page; see WithCtxKeyMeta
	ctxKeyServer string     // set by WithCtxKeyServer; NOTE(review): not read by any code visible here
}

// New creates a host module configured by opts.
// The meta context key starts as DefaultCtxKeyMeta and may be overridden by an option.
func New(opts ...Option) *hostModule {
	hm := &hostModule{}
	hm.ctxKeyMeta = DefaultCtxKeyMeta
	for i := range opts {
		opts[i](hm)
	}
	return hm
}
// Register instantiates the "grpc" host module, making its functions available
// to all module instances in this runtime.
// Called once after a runtime is created, usually on startup.
//
// Exported host functions:
//   - Recv: waits for the next client-stream message.
//       - On arrival it sets errCodeEmpty and copies the message into the
//         guest's message buffer.
//       - When the stream has been closed, or the calling context ends, it
//         sets errCodeDone.
//   - Send: currently a no-op.
func (p *hostModule) Register(ctx context.Context, r wazero.Runtime) (err error) {
	builder := r.NewHostModuleBuilder("grpc")
	register := func(name string, fn func(ctx context.Context, m api.Module, stack []uint64)) {
		builder = builder.NewFunctionBuilder().WithGoModuleFunction(api.GoModuleFunc(fn), nil, nil).Export(name)
	}
	for name, fn := range map[string]any{
		"Recv": func(ctx context.Context) (m []byte, ok bool) {
			next := get[chan []byte](ctx, DefaultCtxKeyNext)
			// Also stop waiting when the RPC ends. If the handler bails out
			// without closing next, the guest would otherwise stay parked here
			// forever. ok=false is reported to the guest as errCodeDone.
			select {
			case m, ok = <-next:
			case <-ctx.Done():
			}
			return
		},
		"Send": func(ctx context.Context, m []byte) {
			// Send message
		},
	} {
		switch fn := fn.(type) {
		case func(context.Context) ([]byte, bool):
			register(name, func(ctx context.Context, m api.Module, stack []uint64) {
				meta := get[*meta](ctx, p.ctxKeyMeta)
				b, ok := fn(ctx)
				if !ok {
					setErrCode(m, meta, errCodeDone)
					return
				}
				setErrCode(m, meta, errCodeEmpty)
				setMsg(m, meta, b)
			})
		case func(context.Context, []byte):
			register(name, func(ctx context.Context, m api.Module, stack []uint64) {
				meta := get[*meta](ctx, p.ctxKeyMeta)
				fn(ctx, msg(m, meta))
			})
		default:
			log.Panicf("Method signature implementation missing: %#v", fn)
		}
	}
	p.module, err = builder.Instantiate(ctx)
	return
}
// initContext populates the meta page in context for a given module instance.
// Called per module instance immediately after module instantiation.
//
// It calls the guest's exported `grpc` function, whose first result is the
// address of a meta page of seven consecutive little-endian uint32 addresses.
// It decodes that page and stores the result in ctx under p.ctxKeyMeta.
// The following are reported as errors instead of panicking or yielding zeroed
// addresses:
//   - a missing export;
//   - an empty result;
//   - a missing memory;
//   - an out-of-range page.
func (p *hostModule) initContext(ctx context.Context, m api.Module) (context.Context, *meta, error) {
	fn := m.ExportedFunction(`grpc`)
	if fn == nil {
		return ctx, nil, fmt.Errorf("module does not export %q", "grpc")
	}
	stack, err := fn.Call(ctx)
	if err != nil {
		return ctx, nil, err
	}
	if len(stack) == 0 {
		return ctx, nil, fmt.Errorf("%q returned no meta page address", "grpc")
	}
	mem := m.Memory()
	if mem == nil {
		return ctx, nil, fmt.Errorf("module exports no memory")
	}
	mt := &meta{}
	base := uint32(stack[0])
	// Field order of the meta page, one uint32 every 4 bytes.
	fields := []*uint32{
		&mt.ptrMethodMax,
		&mt.ptrMethodLen,
		&mt.ptrMethod,
		&mt.ptrMsgMax,
		&mt.ptrMsgLen,
		&mt.ptrMsg,
		&mt.ptrErrCode,
	}
	for i, dst := range fields {
		addr := base + uint32(i)*4
		v, ok := mem.ReadUint32Le(addr)
		if !ok {
			return ctx, nil, fmt.Errorf("reading meta page field %d at %d: out of range", i, addr)
		}
		*dst = v
	}
	return context.WithValue(ctx, p.ctxKeyMeta, mt), mt, nil
}
// RegisterServices attaches the guest's gRPC service(s) to the grpc server.
// Called once before the server starts, usually given a module instance pool.
//
// It borrows one instance from the pool, runs initContext on it, and parses the
// service list left in the guest's message buffer. The list alternates service
// names and comma-separated "kind.method" lists, e.g.
//
//	/package1.ServiceName/u.method1,u.method2,c.method3/service2.ServiceName/u.method1,s.method2
func (p *hostModule) RegisterServices(ctx context.Context, s *grpc.Server, pool wazeropool.Instance) (context.Context, error) {
	inst := pool.Get()
	defer pool.Put(inst)

	ctx, mt, err := p.initContext(ctx, inst)
	if err != nil {
		return ctx, err
	}
	// NOTE(review): mt is reused for every pooled instance. This assumes all
	// instances place the meta page and its buffers at the same addresses.
	segments := strings.Split(string(msg(inst, mt)), "/")
	for i := 1; i+1 < len(segments); i += 2 {
		service, methodList := segments[i], segments[i+1]
		p.registerService(s, pool, mt, service, strings.Split(methodList, ","))
	}
	return ctx, nil
}
// registerService registers serviceName on s with one stream handler per entry
// in methods.
//
// Each entry has the form "<kind>.<name>". kind selects the adapter that
// drives the guest:
//   - "u": unary
//   - "c": client streaming
//   - "s": server streaming
//
// Every method is declared with both ClientStreams and ServerStreams set.
// It panics on a malformed entry or an unknown kind; an unknown kind would
// otherwise leave the stream's Handler nil.
func (p *hostModule) registerService(s *grpc.Server, pool wazeropool.Instance, meta *meta, serviceName string, methods []string) {
	h := &grpcHandler{pool, meta}
	fakeDesc := &grpc.ServiceDesc{
		ServiceName: serviceName,
		HandlerType: (*any)(nil),
	}
	for _, m := range methods {
		parts := strings.Split(m, ".")
		if len(parts) < 2 {
			log.Panicf(`%s %#v`, methods, parts)
		}
		var factory handlerFactory
		switch parts[0] {
		case "u":
			factory = newHandlerUnary
		case "c":
			factory = newHandlerClientStream
		case "s":
			factory = newHandlerServerStream
		default:
			log.Panicf("unknown method kind %q in %q for service %s", parts[0], m, serviceName)
		}
		fakeDesc.Streams = append(fakeDesc.Streams, grpc.StreamDesc{
			StreamName:    parts[1],
			Handler:       h.handler(factory),
			ServerStreams: true,
			ClientStreams: true,
		})
	}
	s.RegisterService(fakeDesc, h)
}
// Stop releases resources held by the host module.
// It currently has nothing to release and always returns nil.
func (p *hostModule) Stop() error {
	return nil
}
func get[T any](ctx context.Context, key string) T {
v := ctx.Value(key)
if v == nil {
log.Panicf("Context item missing %s", key)
}
return v.(T)
}
// Accessors for the guest buffers described by meta. Slices returned by
// read/readRegion come from api.Memory.Read, which wazero documents as a view
// of guest memory rather than a copy.

// method returns the method name currently stored in the guest's method buffer.
func method(m api.Module, meta *meta) []byte {
	return read(m, meta.ptrMethod, meta.ptrMethodLen, meta.ptrMethodMax)
}

// errCode returns the value in the guest's error-code slot.
func errCode(m api.Module, meta *meta) uint32 {
	return readUint32(m, meta.ptrErrCode)
}

// setErrCode writes code into the guest's error-code slot.
func setErrCode(m api.Module, meta *meta, code uint32) {
	writeUint32(m, meta.ptrErrCode, code)
}

// methodBuf returns the guest's whole method buffer (its full capacity).
func methodBuf(m api.Module, meta *meta) []byte {
	// The buffer size comes from ptrMethodMax only. Previously a literal 0 was
	// passed as ptrLen, so a length was read from guest address 0.
	return readRegion(m, meta.ptrMethod, meta.ptrMethodMax)
}

// setMethod copies method into the guest's method buffer and records its length.
// It panics if method does not fit.
func setMethod(m api.Module, meta *meta, method []byte) {
	buf := methodBuf(m, meta)
	if len(method) > len(buf) {
		log.Panicf("method of %d bytes exceeds guest buffer of %d bytes", len(method), len(buf))
	}
	copy(buf, method)
	writeUint32(m, meta.ptrMethodLen, uint32(len(method)))
}

// msg returns the message currently stored in the guest's message buffer.
func msg(m api.Module, meta *meta) []byte {
	return read(m, meta.ptrMsg, meta.ptrMsgLen, meta.ptrMsgMax)
}

// msgBuf returns the guest's whole message buffer (its full capacity).
func msgBuf(m api.Module, meta *meta) []byte {
	return readRegion(m, meta.ptrMsg, meta.ptrMsgMax)
}

// setMsg copies msg into the guest's message buffer and records its length.
// It panics if msg does not fit.
func setMsg(m api.Module, meta *meta, msg []byte) {
	buf := msgBuf(m, meta)
	if len(msg) > len(buf) {
		log.Panicf("message of %d bytes exceeds guest buffer of %d bytes", len(msg), len(buf))
	}
	copy(buf, msg)
	writeUint32(m, meta.ptrMsgLen, uint32(len(msg)))
}

// getError returns the error for the guest's current error code. It returns
// nil for errCodeEmpty, errCodeDone and unrecognized codes. The returned value
// is a shared package-level *Error; callers must not mutate it.
func getError(m api.Module, meta *meta) error {
	if err, ok := errorsByCode[errCode(m, meta)]; ok {
		return err
	}
	return nil
}

// read returns the used part of a guest buffer:
//   - ptrData is the buffer address;
//   - ptrMax is the address of its uint32 capacity;
//   - ptrLen is the address of its uint32 current length.
func read(m api.Module, ptrData, ptrLen, ptrMax uint32) []byte {
	buf := readRegion(m, ptrData, ptrMax)
	n := readUint32(m, ptrLen)
	if uint64(n) > uint64(len(buf)) {
		log.Panicf("length %d (at %d) exceeds buffer capacity %d", n, ptrLen, len(buf))
	}
	return buf[:n]
}

// readRegion returns the guest memory region at ptrData. Its size is the uint32
// stored at ptrMax.
func readRegion(m api.Module, ptrData, ptrMax uint32) []byte {
	size := readUint32(m, ptrMax)
	buf, ok := m.Memory().Read(ptrData, size)
	if !ok {
		log.Panicf("Memory.Read(%d, %d) out of range", ptrData, size)
	}
	return buf
}

// readUint32 reads a little-endian uint32 from guest memory, panicking if out of range.
func readUint32(m api.Module, ptr uint32) (val uint32) {
	val, ok := m.Memory().ReadUint32Le(ptr)
	if !ok {
		log.Panicf("Memory.Read(%d) out of range", ptr)
	}
	return
}

// writeUint32 writes a little-endian uint32 to guest memory, panicking if out of range.
func writeUint32(m api.Module, ptr uint32, val uint32) {
	if ok := m.Memory().WriteUint32Le(ptr, val); !ok {
		log.Panicf("Memory.WriteUint32Le(%d) out of range", ptr)
	}
}
package wazero_grpc_server
// Option configures a hostModule built by New.
type Option func(*hostModule)

// WithCtxKeyMeta overrides the context key used to store and look up the meta
// page. The default is DefaultCtxKeyMeta.
// NOTE(review): newHandlerClientStream always stores meta under
// DefaultCtxKeyMeta. Confirm that overriding this key does not break the host
// functions during client-streaming calls.
func WithCtxKeyMeta(key string) Option {
	opt := func(hm *hostModule) { hm.ctxKeyMeta = key }
	return opt
}

// WithCtxKeyServer sets the hostModule's server context key.
// NOTE(review): the field it sets is not read by any code visible here.
func WithCtxKeyServer(key string) Option {
	opt := func(hm *hostModule) { hm.ctxKeyServer = key }
	return opt
}