package monkey
import (
"reflect"
"runtime"
"strings"
"github.com/agiledragon/gomonkey/v2"
"github.com/agiledragon/gomonkey/v2/creflect"
)
// NewPatches creates an empty patch set by delegating to gomonkey.NewPatches.
//
// Callers should arrange for `defer patches.Reset()` on every returned value
// so that the applied patches are undone.
func NewPatches() *gomonkey.Patches {
	fresh := gomonkey.NewPatches()
	return fresh
}
// Func patches the function target with replacement; the shared type
// parameter Fn forces both to have the same type at compile time.
//
// When patches is nil a fresh patch set is created first. The value returned
// is whatever ApplyFunc returns for the patch set that was used.
func Func[Fn any](patches *gomonkey.Patches, target Fn, replacement Fn) *gomonkey.Patches {
	if patches != nil {
		return patches.ApplyFunc(target, replacement)
	}
	return gomonkey.NewPatches().ApplyFunc(target, replacement)
}
// Var patches the variable pointed to by target with replacement; the type
// parameter ties the pointee type of target to the type of replacement.
//
// When patches is nil a fresh patch set is created first. The value returned
// is whatever ApplyGlobalVar returns for the patch set that was used.
func Var[Var any](patches *gomonkey.Patches, target *Var, replacement Var) *gomonkey.Patches {
	set := patches
	if set == nil {
		set = gomonkey.NewPatches()
	}
	return set.ApplyGlobalVar(target, replacement)
}
// getMethodName returns the bare name of the function or method value m.
//
// The runtime's symbol name for m is looked up, everything up to and
// including the last "." is dropped, and anything from the first "-" onwards
// is cut off (for example the "-fm" suffix carried by method-value wrappers).
//
// It returns "" when m is nil, is not a func, or cannot be resolved to a
// runtime function, so that the caller can report a meaningful error instead
// of hitting a panic inside package reflect.
func getMethodName(m any) string {
	v := reflect.ValueOf(m)
	// Value.Pointer panics for kinds that are not pointer-like (and for the
	// zero Value produced by a nil interface), so check the kind up front.
	if v.Kind() != reflect.Func {
		return ""
	}
	fn := runtime.FuncForPC(v.Pointer())
	if fn == nil {
		return ""
	}
	fullName := fn.Name()
	// LastIndex yields -1 when there is no ".", which makes the slice start at 0.
	short := fullName[strings.LastIndex(fullName, ".")+1:]
	name, _, _ := strings.Cut(short, "-")
	return name
}
// funcToMethod adapts doubleFunc into a function value whose signature has an
// extra leading parameter of type receiverType.
//
// The returned function has the same remaining parameters, results and
// variadic-ness as doubleFunc. When it is called, the leading receiver
// argument is discarded and the rest are forwarded to doubleFunc unchanged.
//
// It panics with "doubleFunc is not a func" when doubleFunc is nil or is not
// of kind Func.
func funcToMethod(receiverType reflect.Type, doubleFunc any) reflect.Value {
	rf := reflect.TypeOf(doubleFunc)
	// reflect.TypeOf returns a nil Type for a nil interface value; guard it so
	// the caller gets the intended message rather than a nil dereference.
	if rf == nil || rf.Kind() != reflect.Func {
		panic("doubleFunc is not a func")
	}
	vf := reflect.ValueOf(doubleFunc)

	// Parameter list: the receiver first, then doubleFunc's own parameters.
	inParams := make([]reflect.Type, 0, rf.NumIn()+1)
	inParams = append(inParams, receiverType)
	for i := 0; i < rf.NumIn(); i++ {
		inParams = append(inParams, rf.In(i))
	}

	outParams := make([]reflect.Type, 0, rf.NumOut())
	for i := 0; i < rf.NumOut(); i++ {
		outParams = append(outParams, rf.Out(i))
	}

	variadic := rf.IsVariadic()
	funcType := reflect.FuncOf(inParams, outParams, variadic)

	return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value {
		args := in[1:] // drop the receiver
		if variadic {
			// For a variadic function the final argument arrives here already
			// packed as a slice, so it has to be forwarded with CallSlice.
			return vf.CallSlice(args)
		}
		return vf.Call(args)
	})
}
// Method patches a method of receiver with replacement.
//
// The method's name is derived from the func value passed as method (via
// getMethodName); the call panics with "method is not a method" when no name
// can be derived. When patches is nil a fresh patch set is created first.
//
// Names starting with an ASCII upper-case letter are handed to
// ApplyMethodFunc. Any other name is resolved on receiver's dynamic type with
// creflect.MethodByName (panicking with "retrieve method by name failed" if
// it is not found), and replacement is wrapped by funcToMethod so that it
// accepts and ignores a leading receiver argument before being applied with
// ApplyCoreOnlyForPrivateMethod.
//
// NOTE(review): because method and replacement share the type Fn and a
// receiver parameter is prepended to replacement, method is presumably meant
// to be a method value (e.g. obj.foo) rather than a method expression —
// confirm against callers.
func Method[Receiver any, Fn any](patches *gomonkey.Patches, receiver Receiver, method Fn, replacement Fn) *gomonkey.Patches {
	if patches == nil {
		patches = gomonkey.NewPatches()
	}

	methodName := getMethodName(method)
	if len(methodName) == 0 {
		panic("method is not a method")
	}

	// Exported (ASCII upper-case initial) methods go through the regular API.
	if first := methodName[0]; 'A' <= first && first <= 'Z' {
		return patches.ApplyMethodFunc(receiver, methodName, replacement)
	}

	receiverType := reflect.TypeOf(receiver)
	target, found := creflect.MethodByName(receiverType, methodName)
	if !found {
		panic("retrieve method by name failed")
	}
	return patches.ApplyCoreOnlyForPrivateMethod(target, funcToMethod(receiverType, replacement))
}