package pjson
import (
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"reflect"
)
type (
	// Variant is implemented by every concrete payload type that can be held
	// in a Tagged value. The string it returns is the discriminator value that
	// identifies the payload's type in the encoded JSON object.
	Variant interface {
		Variant() string
	}

	// Discriminator describes a tagged union. Field returns the gjson/sjson
	// path of the discriminator property inside the JSON object, and Variants
	// returns one prototype value per member type; decoding picks the
	// prototype whose Variant() equals the discriminator found in the input.
	Discriminator interface {
		Field() string
		Variants() []Variant
	}

	// Tagged holds a single member of the union described by T and
	// (un)marshals it as a JSON object carrying a discriminator property.
	Tagged[T Discriminator] struct {
		// d is never set by the (un)marshal methods; they call Field and
		// Variants on whatever d holds, normally T's zero value, so T's methods
		// should work on a zero receiver. NOTE(review): confirm nothing else
		// assigns d.
		d T
		// Value is the payload; a nil Value marshals as JSON null.
		Value Variant
	}
)
// MarshalJSON encodes o.Value with encoding/json and then writes the string
// returned by o.Value.Variant() at the discriminator path o.d.Field() of the
// encoded output using sjson.SetBytes.
//
// A nil Value encodes as JSON null. So does an interface holding a typed nil
// pointer. Without that check, Variant() would be called on a nil pointer,
// which panics for value-receiver methods. Even if it did not panic, the
// payload would encode as null, leaving no object to add the discriminator to.
func (o Tagged[T]) MarshalJSON() ([]byte, error) {
	if o.Value == nil {
		return json.Marshal(o.Value)
	}
	if rv := reflect.ValueOf(o.Value); rv.Kind() == reflect.Pointer && rv.IsNil() {
		return []byte("null"), nil
	}
	variant := o.Value.Variant()
	b, err := json.Marshal(o.Value)
	if err != nil {
		return nil, err
	}
	// The payload's own encoding is expected to be a JSON object; sjson sets
	// the discriminator property at the Field() path inside it.
	return sjson.SetBytes(b, o.d.Field(), variant)
}
// UnmarshalJSON decodes a JSON object into o.Value. It picks the concrete type
// by comparing the string at the discriminator path o.d.Field() against
// Variant() of each prototype returned by o.d.Variants().
//
// Empty input or the literal null is a no-op and leaves o.Value unchanged. It
// is an error if the input is not an object, if the discriminator is missing
// or empty, or if no prototype matches it.
//
// The matching prototype is never decoded into directly. A fresh value of the
// same type is allocated instead, so prototypes shared between calls (for
// example a package-level Variants slice) are neither mutated nor aliased by
// every decoded Tagged.
//
// For a non-nil pointer prototype, its pointee is shallow-copied into the
// fresh value first, so preset top-level fields still act as defaults. A
// typed-nil pointer prototype starts from the zero value. A non-pointer
// prototype T is decoded into a zeroed *T. Either way, o.Value ends up holding
// a pointer.
func (o *Tagged[T]) UnmarshalJSON(bytes []byte) error {
	if len(bytes) == 0 || string(bytes) == "null" {
		return nil
	}
	jRes := gjson.ParseBytes(bytes)
	if !jRes.IsObject() {
		return fmt.Errorf("did not hold an Object")
	}
	field := o.d.Field()
	variantRes := jRes.Get(field)
	if !variantRes.Exists() {
		return fmt.Errorf("failed to find variant field '%s' in json object", field)
	}
	variantValue := variantRes.String()
	if variantValue == "" {
		return fmt.Errorf("variant field '%s' was empty", field)
	}
	for _, proto := range o.d.Variants() {
		// A nil interface entry has no dynamic type and calling Variant() on it
		// would panic, so it can never be a match.
		if proto == nil || proto.Variant() != variantValue {
			continue
		}
		t := reflect.TypeOf(proto)
		var fresh reflect.Value
		if t.Kind() == reflect.Pointer {
			// Allocate a new pointee instead of decoding through the
			// prototype's pointer; the new value has type t, so it still
			// satisfies Variant.
			fresh = reflect.New(t.Elem())
			if src := reflect.ValueOf(proto); !src.IsNil() {
				fresh.Elem().Set(src.Elem())
			}
		} else {
			// *T carries T's value-receiver methods, so it satisfies Variant.
			fresh = reflect.New(t)
		}
		target := fresh.Interface()
		if err := json.Unmarshal(bytes, target); err != nil {
			return fmt.Errorf("failed to unmarshal variant '%s': %w", variantValue, err)
		}
		o.Value = target.(Variant)
		return nil
	}
	return fmt.Errorf("no variant matched type '%s'", variantValue)
}