package automapper import ( "fmt" "reflect" ) type Config[S any, D any] struct { fieldMappings map[string]Opts destType reflect.Type destFields int srcType reflect.Type srcFields int } // New creates a new config for mapping from S (source) to D (destination) func New[S any, D any]() Config[S, D] { dest := new(D) src := new(S) destType := reflect.TypeOf(*dest) srcType := reflect.TypeOf(*src) if destType.Kind() != reflect.Struct || destType.Kind() == reflect.Ptr && destType.Elem().Kind() != reflect.Struct { // must be a pointer to a struct panic("destination type D was not a struct or pointer to one") } if srcType.Kind() != reflect.Struct || srcType.Kind() == reflect.Ptr && srcType.Elem().Kind() != reflect.Struct { // must be a pointer to a struct panic("source type T was not a struct or pointer to one") } m := Config[S, D]{ fieldMappings: map[string]Opts{}, destType: destType, destFields: destType.NumField(), srcType: srcType, srcFields: srcType.NumField(), } return m } type IncompatibleTypesErr struct { src reflect.Type dest reflect.Type } func (i IncompatibleTypesErr) Error() string { return fmt.Sprintf("destination type is %s, source is %s", i.dest, i.src) } func (m Config[S, D]) mapAny(srcType reflect.Type, srcValue reflect.Value, destType reflect.Type, destValue reflect.Value) error { switch destType.Kind() { case reflect.Struct: if srcType.Kind() != reflect.Struct { return IncompatibleTypesErr{src: srcType, dest: destType} } if srcType != destType { return IncompatibleTypesErr{src: srcType, dest: destType} } destValue.Set(srcValue) return nil case reflect.Slice: if srcType.Kind() != reflect.Slice { return IncompatibleTypesErr{src: srcType, dest: destType} } if srcType.Elem().Kind() != destType.Elem().Kind() { return IncompatibleTypesErr{src: srcType, dest: destType} } if srcValue.IsNil() { return nil } if srcType.Elem() != destType.Elem() { // need to cast destValue.Set(reflect.MakeSlice(destType, srcValue.Len(), srcValue.Len())) for i := 0; i < srcValue.Len(); i++ { elemValue := srcValue.Index(i) destValue.Index(i).Set(elemValue.Convert(destType.Elem())) } return nil } destValue.Set(srcValue) case reflect.Pointer: referencedDestType := destType.Elem() referencedSourceType := srcType.Elem() if referencedSourceType.Kind() != referencedDestType.Kind() { return IncompatibleTypesErr{src: srcType, dest: destType} } if srcValue.IsNil() { return nil } if referencedSourceType == referencedDestType { destValue.Set(srcValue) } destValue.Set(srcValue.Convert(destValue.Type())) default: if srcType.Kind() != destType.Kind() { return IncompatibleTypesErr{src: srcType, dest: destType} } if destValue.Type() != srcValue.Type() { srcValue = srcValue.Convert(destValue.Type()) } destValue.Set(srcValue) } return nil } // MapSlice maps slices... func (m Config[S, D]) MapSlice(src []S) ([]D, error) { var ret []D for _, item := range src { mappedItem, err := m.Map(item) if err != nil { return nil, err } ret = append(ret, mappedItem) } return ret, nil } // Map from S (source type) to D (destination type) func (m Config[S, D]) Map(src S) (D, error) { dest := new(D) srcValue := reflect.ValueOf(&src) srcType := m.srcType destValue := reflect.ValueOf(dest) destType := m.destType for j := 0; j < m.destFields; j++ { destFieldType := destType.Field(j) if !destFieldType.IsExported() { continue } destFieldValue := destValue.Elem().Field(j) found := false fieldMapping, ok := m.fieldMappings[destFieldType.Name] if ok { err := fieldMapping.apply(src, destFieldValue) if err != nil { return *dest, err } continue } for i := 0; i < m.srcFields; i++ { srcFieldValue := srcValue.Elem().Field(i) srcFieldType := srcType.Field(i) if destFieldType.Name == srcFieldType.Name { found = true if err := m.mapAny(srcFieldType.Type, srcFieldValue, destFieldType.Type, destFieldValue); err != nil { return *dest, err } break } } if !found { return *dest, fmt.Errorf("field '%s' not found in source type '%s'", destFieldType.Name, reflect.TypeOf(src)) } } return *dest, nil } type Opts struct { mapFunc func(any) (any, error) ignore bool } func (o Opts) apply(src any, destValue reflect.Value) error { if o.ignore { return nil } if o.mapFunc == nil { return nil } v, err := o.mapFunc(src) if err != nil { return err } destValue.Set(reflect.ValueOf(v)) return nil } // IgnoreField asks the Config to avoid mapping this field func IgnoreField() func(o *Opts) { return func(o *Opts) { o.ignore = true } } // MapField asks the config to map with the mapFunc func MapField[S any](mapFunc func(S) (any, error)) func(o *Opts) { return func(o *Opts) { o.mapFunc = func(s any) (any, error) { return mapFunc(s.(S)) } } } // findStructField looks for a field in the given struct. // The field being looked for should be a pointer to the actual struct field. // If found, the field info will be returned. Otherwise, nil will be returned. // From https://github.com/go-ozzo/ozzo-validation func findStructField(structValue reflect.Value, fieldValue reflect.Value) *reflect.StructField { ptr := fieldValue.Pointer() for i := structValue.NumField() - 1; i >= 0; i-- { sf := structValue.Type().Field(i) if ptr == structValue.Field(i).UnsafeAddr() { // do additional type comparison because it's possible that the address of // an embedded struct is the same as the first field of the embedded struct if sf.Type == fieldValue.Elem().Type() { return &sf } } if sf.Anonymous { // delve into anonymous struct to look for the field fi := structValue.Field(i) if sf.Type.Kind() == reflect.Ptr { fi = fi.Elem() } if fi.Kind() == reflect.Struct { if f := findStructField(fi, fieldValue); f != nil { return f } } } } return nil } // ForFieldName registers mapping options by the struct field's name func (m Config[S, D]) ForFieldName(name string, option func(o *Opts)) Config[S, D] { _, found := m.destType.FieldByName(name) if !found { panic(fmt.Errorf("destination has no field named %s", name)) } opts, found := m.fieldMappings[name] if !found { opts = Opts{} } option(&opts) m.fieldMappings[name] = opts return m } // ForField registers mapping options by a fieldFunc. The fieldFunc MUST return a pointer to a struct field of *D. // ForField(func(d *Foo) any { return &d.Bar }, ...) func (m Config[S, D]) ForField(fieldFunc func(d *D) any, option func(o *Opts)) Config[S, D] { d := new(D) field := fieldFunc(d) structValue := reflect.ValueOf(d) fieldValue := reflect.ValueOf(field) structValue = structValue.Elem() if fieldValue.Kind() != reflect.Ptr { panic("fieldFunc return value must be pointer to struct field") } structField := findStructField(structValue, fieldValue) if structField == nil { panic(fmt.Errorf("struct field could not be identified from fieldFunc")) } return m.ForFieldName(structField.Name, option) } // MapSlice generic helper to map []A to []B func MapSlice[A, B any](slice []A, mapper func(input A) B) (res []B) { for _, item := range slice { res = append(res, mapper(item)) } return }