Compare commits
	
		
			13 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 021465e524 | |||
| cf9c73aa4a | |||
| 0652bf22dc | |||
| b196adffc7 | |||
| 717065e62d | |||
| e7b2b040b2 | |||
| 05d0f9e469 | |||
| ccd03e50c8 | |||
| 1c77c2b8e8 | |||
| 9f6f967299 | |||
| 18c83f0f76 | |||
| a64f336e24 | |||
| 14bbd205f8 | 
| @@ -1274,7 +1274,7 @@ func TestMarshalSafeCollections(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for i, tt := range tests { | 	for i, tt := range tests { | ||||||
| 		b, err := MarshalSafeCollections(tt.in, true, true) | 		b, err := MarshalSafeCollections(tt.in, true, true, nil) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("test %d, unexpected failure: %v", i, err) | 			t.Errorf("test %d, unexpected failure: %v", i, err) | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -433,3 +433,10 @@ func ArrConcat[T any](arr ...[]T) []T { | |||||||
| 	} | 	} | ||||||
| 	return r | 	return r | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // ArrCopy does a shallow copy of the 'in' array | ||||||
|  | func ArrCopy[T any](in []T) []T { | ||||||
|  | 	out := make([]T, len(in)) | ||||||
|  | 	copy(out, in) | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|   | |||||||
| @@ -31,16 +31,16 @@ func CompareIntArr(arr1 []int, arr2 []int) bool { | |||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
| func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) bool { | func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) int { | ||||||
|  |  | ||||||
| 	for i := 0; i < len(arr1) || i < len(arr2); i++ { | 	for i := 0; i < len(arr1) || i < len(arr2); i++ { | ||||||
|  |  | ||||||
| 		if i < len(arr1) && i < len(arr2) { | 		if i < len(arr1) && i < len(arr2) { | ||||||
|  |  | ||||||
| 			if arr1[i] < arr2[i] { | 			if arr1[i] < arr2[i] { | ||||||
| 				return true | 				return -1 | ||||||
| 			} else if arr1[i] > arr2[i] { | 			} else if arr1[i] > arr2[i] { | ||||||
| 				return false | 				return +2 | ||||||
| 			} else { | 			} else { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| @@ -49,15 +49,55 @@ func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) bool { | |||||||
|  |  | ||||||
| 		if i < len(arr1) { | 		if i < len(arr1) { | ||||||
|  |  | ||||||
| 			return true | 			return +1 | ||||||
|  |  | ||||||
| 		} else { // if i < len(arr2) | 		} else { // if i < len(arr2) | ||||||
|  |  | ||||||
| 			return false | 			return -1 | ||||||
|  |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return false | 	return 0 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func CompareString(a, b string) int { | ||||||
|  | 	if a == b { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	if a < b { | ||||||
|  | 		return -1 | ||||||
|  | 	} | ||||||
|  | 	return +1 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func CompareInt(a, b int) int { | ||||||
|  | 	if a == b { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	if a < b { | ||||||
|  | 		return -1 | ||||||
|  | 	} | ||||||
|  | 	return +1 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func CompareInt64(a, b int64) int { | ||||||
|  | 	if a == b { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	if a < b { | ||||||
|  | 		return -1 | ||||||
|  | 	} | ||||||
|  | 	return +1 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Compare[T OrderedConstraint](a, b T) int { | ||||||
|  | 	if a == b { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	if a < b { | ||||||
|  | 		return -1 | ||||||
|  | 	} | ||||||
|  | 	return +1 | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										111
									
								
								langext/reflection.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								langext/reflection.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,111 @@ | |||||||
|  | package langext | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"reflect" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var reflectBasicTypes = []reflect.Type{ | ||||||
|  | 	reflect.Bool:       reflect.TypeOf(false), | ||||||
|  | 	reflect.Int:        reflect.TypeOf(int(0)), | ||||||
|  | 	reflect.Int8:       reflect.TypeOf(int8(0)), | ||||||
|  | 	reflect.Int16:      reflect.TypeOf(int16(0)), | ||||||
|  | 	reflect.Int32:      reflect.TypeOf(int32(0)), | ||||||
|  | 	reflect.Int64:      reflect.TypeOf(int64(0)), | ||||||
|  | 	reflect.Uint:       reflect.TypeOf(uint(0)), | ||||||
|  | 	reflect.Uint8:      reflect.TypeOf(uint8(0)), | ||||||
|  | 	reflect.Uint16:     reflect.TypeOf(uint16(0)), | ||||||
|  | 	reflect.Uint32:     reflect.TypeOf(uint32(0)), | ||||||
|  | 	reflect.Uint64:     reflect.TypeOf(uint64(0)), | ||||||
|  | 	reflect.Uintptr:    reflect.TypeOf(uintptr(0)), | ||||||
|  | 	reflect.Float32:    reflect.TypeOf(float32(0)), | ||||||
|  | 	reflect.Float64:    reflect.TypeOf(float64(0)), | ||||||
|  | 	reflect.Complex64:  reflect.TypeOf(complex64(0)), | ||||||
|  | 	reflect.Complex128: reflect.TypeOf(complex128(0)), | ||||||
|  | 	reflect.String:     reflect.TypeOf(""), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Underlying returns the underlying type of t (without type alias) | ||||||
|  | // | ||||||
|  | // https://github.com/golang/go/issues/39574#issuecomment-655664772 | ||||||
|  | func Underlying(t reflect.Type) (ret reflect.Type) { | ||||||
|  | 	if t.Name() == "" { | ||||||
|  | 		// t is an unnamed type. the underlying type is t itself | ||||||
|  | 		return t | ||||||
|  | 	} | ||||||
|  | 	kind := t.Kind() | ||||||
|  | 	if ret = reflectBasicTypes[kind]; ret != nil { | ||||||
|  | 		return ret | ||||||
|  | 	} | ||||||
|  | 	switch kind { | ||||||
|  | 	case reflect.Array: | ||||||
|  | 		ret = reflect.ArrayOf(t.Len(), t.Elem()) | ||||||
|  | 	case reflect.Chan: | ||||||
|  | 		ret = reflect.ChanOf(t.ChanDir(), t.Elem()) | ||||||
|  | 	case reflect.Map: | ||||||
|  | 		ret = reflect.MapOf(t.Key(), t.Elem()) | ||||||
|  | 	case reflect.Func: | ||||||
|  | 		nIn := t.NumIn() | ||||||
|  | 		nOut := t.NumOut() | ||||||
|  | 		in := make([]reflect.Type, nIn) | ||||||
|  | 		out := make([]reflect.Type, nOut) | ||||||
|  | 		for i := 0; i < nIn; i++ { | ||||||
|  | 			in[i] = t.In(i) | ||||||
|  | 		} | ||||||
|  | 		for i := 0; i < nOut; i++ { | ||||||
|  | 			out[i] = t.Out(i) | ||||||
|  | 		} | ||||||
|  | 		ret = reflect.FuncOf(in, out, t.IsVariadic()) | ||||||
|  | 	case reflect.Interface: | ||||||
|  | 		// not supported | ||||||
|  | 	case reflect.Ptr: | ||||||
|  | 		ret = reflect.PtrTo(t.Elem()) | ||||||
|  | 	case reflect.Slice: | ||||||
|  | 		ret = reflect.SliceOf(t.Elem()) | ||||||
|  | 	case reflect.Struct: | ||||||
|  | 		// only partially supported: embedded fields | ||||||
|  | 		// and unexported fields may cause panic in reflect.StructOf() | ||||||
|  | 		defer func() { | ||||||
|  | 			// if a panic happens, return t unmodified | ||||||
|  | 			if recover() != nil && ret == nil { | ||||||
|  | 				ret = t | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
|  | 		n := t.NumField() | ||||||
|  | 		fields := make([]reflect.StructField, n) | ||||||
|  | 		for i := 0; i < n; i++ { | ||||||
|  | 			fields[i] = t.Field(i) | ||||||
|  | 		} | ||||||
|  | 		ret = reflect.StructOf(fields) | ||||||
|  | 	} | ||||||
|  | 	return ret | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TryCast works similar to `v2, ok := v.(T)` | ||||||
|  | // Except it works through type alias' | ||||||
|  | func TryCast[T any](v any) (T, bool) { | ||||||
|  |  | ||||||
|  | 	underlying := Underlying(reflect.TypeOf(v)) | ||||||
|  |  | ||||||
|  | 	def := *new(T) | ||||||
|  |  | ||||||
|  | 	if underlying != Underlying(reflect.TypeOf(def)) { | ||||||
|  | 		return def, false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r1 := reflect.ValueOf(v) | ||||||
|  |  | ||||||
|  | 	if !r1.CanConvert(underlying) { | ||||||
|  | 		return def, false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r2 := r1.Convert(underlying) | ||||||
|  |  | ||||||
|  | 	r3 := r2.Interface() | ||||||
|  |  | ||||||
|  | 	r4, ok := r3.(T) | ||||||
|  | 	if !ok { | ||||||
|  | 		return def, false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return r4, true | ||||||
|  | } | ||||||
| @@ -22,6 +22,31 @@ func Max[T langext.OrderedConstraint](v1 T, v2 T) T { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func Max3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T { | ||||||
|  | 	result := v1 | ||||||
|  | 	if v2 > result { | ||||||
|  | 		result = v2 | ||||||
|  | 	} | ||||||
|  | 	if v3 > result { | ||||||
|  | 		result = v3 | ||||||
|  | 	} | ||||||
|  | 	return result | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Max4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T { | ||||||
|  | 	result := v1 | ||||||
|  | 	if v2 > result { | ||||||
|  | 		result = v2 | ||||||
|  | 	} | ||||||
|  | 	if v3 > result { | ||||||
|  | 		result = v3 | ||||||
|  | 	} | ||||||
|  | 	if v4 > result { | ||||||
|  | 		result = v4 | ||||||
|  | 	} | ||||||
|  | 	return result | ||||||
|  | } | ||||||
|  |  | ||||||
| func Min[T langext.OrderedConstraint](v1 T, v2 T) T { | func Min[T langext.OrderedConstraint](v1 T, v2 T) T { | ||||||
| 	if v1 < v2 { | 	if v1 < v2 { | ||||||
| 		return v1 | 		return v1 | ||||||
| @@ -30,6 +55,31 @@ func Min[T langext.OrderedConstraint](v1 T, v2 T) T { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func Min3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T { | ||||||
|  | 	result := v1 | ||||||
|  | 	if v2 < result { | ||||||
|  | 		result = v2 | ||||||
|  | 	} | ||||||
|  | 	if v3 < result { | ||||||
|  | 		result = v3 | ||||||
|  | 	} | ||||||
|  | 	return result | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Min4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T { | ||||||
|  | 	result := v1 | ||||||
|  | 	if v2 < result { | ||||||
|  | 		result = v2 | ||||||
|  | 	} | ||||||
|  | 	if v3 < result { | ||||||
|  | 		result = v3 | ||||||
|  | 	} | ||||||
|  | 	if v4 < result { | ||||||
|  | 		result = v4 | ||||||
|  | 	} | ||||||
|  | 	return result | ||||||
|  | } | ||||||
|  |  | ||||||
| func Abs[T langext.NumberConstraint](v T) T { | func Abs[T langext.NumberConstraint](v T) T { | ||||||
| 	if v < 0 { | 	if v < 0 { | ||||||
| 		return -v | 		return -v | ||||||
|   | |||||||
							
								
								
									
										49
									
								
								mongoext/pipeline.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								mongoext/pipeline.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | |||||||
|  | package mongoext | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"go.mongodb.org/mongo-driver/bson" | ||||||
|  | 	"go.mongodb.org/mongo-driver/mongo" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // FixTextSearchPipeline moves {$match:{$text:{$search}}} entries to the front of the pipeline (otherwise its an mongo error) | ||||||
|  | func FixTextSearchPipeline(pipeline mongo.Pipeline) mongo.Pipeline { | ||||||
|  |  | ||||||
|  | 	dget := func(v bson.D, k string) (bson.M, bool) { | ||||||
|  | 		for _, e := range v { | ||||||
|  | 			if e.Key == k { | ||||||
|  | 				if mv, ok := e.Value.(bson.M); ok { | ||||||
|  | 					return mv, true | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return nil, false | ||||||
|  | 	} | ||||||
|  | 	mget := func(v bson.M, k string) (bson.M, bool) { | ||||||
|  | 		for ekey, eval := range v { | ||||||
|  | 			if ekey == k { | ||||||
|  | 				if mv, ok := eval.(bson.M); ok { | ||||||
|  | 					return mv, true | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return nil, false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	result := make([]bson.D, 0, len(pipeline)) | ||||||
|  |  | ||||||
|  | 	for _, entry := range pipeline { | ||||||
|  |  | ||||||
|  | 		if v0, ok := dget(entry, "$match"); ok { | ||||||
|  | 			if v1, ok := mget(v0, "$text"); ok { | ||||||
|  | 				if _, ok := v1["$search"]; ok { | ||||||
|  | 					result = append([]bson.D{entry}, result...) | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		result = append(result, entry) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return result | ||||||
|  | } | ||||||
							
								
								
									
										30
									
								
								mongoext/projections.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								mongoext/projections.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | |||||||
|  | package mongoext | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"go.mongodb.org/mongo-driver/bson" | ||||||
|  | 	"reflect" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ProjectionFromStruct automatically generated a mongodb projection for a struct | ||||||
|  | // This way you can pretty much always write | ||||||
|  | // `options.FindOne().SetProjection(mongoutils.ProjectionFromStruct(...your_model...))` | ||||||
|  | // to only get the data from mongodb that you will actually use in the later decode step | ||||||
|  | func ProjectionFromStruct(obj interface{}) bson.M { | ||||||
|  | 	v := reflect.ValueOf(obj) | ||||||
|  | 	t := v.Type() | ||||||
|  |  | ||||||
|  | 	result := bson.M{} | ||||||
|  |  | ||||||
|  | 	for i := 0; i < v.NumField(); i++ { | ||||||
|  | 		tag := t.Field(i).Tag.Get("bson") | ||||||
|  | 		if tag == "" { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		tag = strings.Split(tag, ",")[0] | ||||||
|  |  | ||||||
|  | 		result[tag] = 1 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return result | ||||||
|  | } | ||||||
							
								
								
									
										25
									
								
								mongoext/registry.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								mongoext/registry.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | |||||||
|  | package mongoext | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"go.mongodb.org/mongo-driver/bson" | ||||||
|  | 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||||
|  | 	"gogs.mikescher.com/BlackForestBytes/goext/rfctime" | ||||||
|  | 	"reflect" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func CreateGoExtBsonRegistry() *bsoncodec.Registry { | ||||||
|  | 	rb := bsoncodec.NewRegistryBuilder() | ||||||
|  |  | ||||||
|  | 	rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339Time{}), rfctime.RFC3339Time{}) | ||||||
|  | 	rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.RFC3339Time{}), rfctime.RFC3339Time{}) | ||||||
|  |  | ||||||
|  | 	rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{}) | ||||||
|  | 	rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{}) | ||||||
|  |  | ||||||
|  | 	bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) | ||||||
|  | 	bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb) | ||||||
|  |  | ||||||
|  | 	bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) | ||||||
|  |  | ||||||
|  | 	return rb.Build() | ||||||
|  | } | ||||||
| @@ -5,7 +5,10 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"go.mongodb.org/mongo-driver/bson" | 	"go.mongodb.org/mongo-driver/bson" | ||||||
|  | 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||||
|  | 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||||
|  | 	"reflect" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -67,6 +70,10 @@ func (t *RFC3339Time) UnmarshalText(data []byte) error { | |||||||
|  |  | ||||||
| func (t *RFC3339Time) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { | func (t *RFC3339Time) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { | ||||||
| 	if bt == bsontype.Null { | 	if bt == bsontype.Null { | ||||||
|  | 		// we can't set nil in UnmarshalBSONValue (so we use default(struct)) | ||||||
|  | 		// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values | ||||||
|  | 		// https://stackoverflow.com/questions/75167597 | ||||||
|  | 		// https://jira.mongodb.org/browse/GODRIVER-2252 | ||||||
| 		*t = RFC3339Time{} | 		*t = RFC3339Time{} | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @@ -86,6 +93,32 @@ func (t RFC3339Time) MarshalBSONValue() (bsontype.Type, []byte, error) { | |||||||
| 	return bson.MarshalValue(time.Time(t)) | 	return bson.MarshalValue(time.Time(t)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (t RFC3339Time) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||||
|  | 	if val.Kind() == reflect.Ptr && val.IsNil() { | ||||||
|  | 		if !val.CanSet() { | ||||||
|  | 			return errors.New("ValueUnmarshalerDecodeValue") | ||||||
|  | 		} | ||||||
|  | 		val.Set(reflect.New(val.Type().Elem())) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if val.Kind() == reflect.Ptr && len(src) == 0 { | ||||||
|  | 		val.Set(reflect.Zero(val.Type())) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = t.UnmarshalBSONValue(tp, src) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func (t RFC3339Time) Serialize() string { | func (t RFC3339Time) Serialize() string { | ||||||
| 	return t.Time().Format(t.FormatStr()) | 	return t.Time().Format(t.FormatStr()) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -5,7 +5,10 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"go.mongodb.org/mongo-driver/bson" | 	"go.mongodb.org/mongo-driver/bson" | ||||||
|  | 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||||
|  | 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||||
|  | 	"reflect" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -67,6 +70,10 @@ func (t *RFC3339NanoTime) UnmarshalText(data []byte) error { | |||||||
|  |  | ||||||
| func (t *RFC3339NanoTime) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { | func (t *RFC3339NanoTime) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { | ||||||
| 	if bt == bsontype.Null { | 	if bt == bsontype.Null { | ||||||
|  | 		// we can't set nil in UnmarshalBSONValue (so we use default(struct)) | ||||||
|  | 		// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values | ||||||
|  | 		// https://stackoverflow.com/questions/75167597 | ||||||
|  | 		// https://jira.mongodb.org/browse/GODRIVER-2252 | ||||||
| 		*t = RFC3339NanoTime{} | 		*t = RFC3339NanoTime{} | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @@ -86,6 +93,38 @@ func (t RFC3339NanoTime) MarshalBSONValue() (bsontype.Type, []byte, error) { | |||||||
| 	return bson.MarshalValue(time.Time(t)) | 	return bson.MarshalValue(time.Time(t)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (t RFC3339NanoTime) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||||
|  | 	if val.Kind() == reflect.Ptr && val.IsNil() { | ||||||
|  | 		if !val.CanSet() { | ||||||
|  | 			return errors.New("ValueUnmarshalerDecodeValue") | ||||||
|  | 		} | ||||||
|  | 		val.Set(reflect.New(val.Type().Elem())) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if val.Kind() == reflect.Ptr && len(src) == 0 { | ||||||
|  | 		val.Set(reflect.Zero(val.Type())) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = t.UnmarshalBSONValue(tp, src) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if val.Kind() == reflect.Ptr { | ||||||
|  | 		val.Set(reflect.ValueOf(&t)) | ||||||
|  | 	} else { | ||||||
|  | 		val.Set(reflect.ValueOf(t)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func (t RFC3339NanoTime) Serialize() string { | func (t RFC3339NanoTime) Serialize() string { | ||||||
| 	return t.Time().Format(t.FormatStr()) | 	return t.Time().Format(t.FormatStr()) | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user