Compare commits
	
		
			15 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| cf9c73aa4a | |||
| 0652bf22dc | |||
| b196adffc7 | |||
| 717065e62d | |||
| e7b2b040b2 | |||
| 05d0f9e469 | |||
| ccd03e50c8 | |||
| 1c77c2b8e8 | |||
| 9f6f967299 | |||
| 18c83f0f76 | |||
| a64f336e24 | |||
| 14bbd205f8 | |||
| cecfb0d788 | |||
| a445e6f623 | |||
| 0aa6310971 | 
| @@ -167,15 +167,30 @@ func Marshal(v any) ([]byte, error) { | ||||
| 	return buf, nil | ||||
| } | ||||
|  | ||||
| type IndentOpt struct { | ||||
| 	Prefix string | ||||
| 	Indent string | ||||
| } | ||||
|  | ||||
| // MarshalSafeCollections is like Marshal except it will marshal nil maps and | ||||
| // slices as '{}' and '[]' respectfully instead of 'null' | ||||
| func MarshalSafeCollections(v interface{}, nilSafeSlices bool, nilSafeMaps bool) ([]byte, error) { | ||||
| func MarshalSafeCollections(v interface{}, nilSafeSlices bool, nilSafeMaps bool, indent *IndentOpt) ([]byte, error) { | ||||
| 	e := &encodeState{} | ||||
| 	err := e.marshal(v, encOpts{escapeHTML: true, nilSafeSlices: nilSafeSlices, nilSafeMaps: nilSafeMaps}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return e.Bytes(), nil | ||||
| 	b := e.Bytes() | ||||
| 	if indent != nil { | ||||
| 		var buf bytes.Buffer | ||||
| 		err = Indent(&buf, b, indent.Prefix, indent.Indent) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		return buf.Bytes(), nil | ||||
| 	} else { | ||||
| 		return e.Bytes(), nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // MarshalIndent is like Marshal but applies Indent to format the output. | ||||
|   | ||||
| @@ -1274,7 +1274,7 @@ func TestMarshalSafeCollections(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		b, err := MarshalSafeCollections(tt.in, true, true) | ||||
| 		b, err := MarshalSafeCollections(tt.in, true, true, nil) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("test %d, unexpected failure: %v", i, err) | ||||
| 		} | ||||
|   | ||||
							
								
								
									
										44
									
								
								gojson/gionic.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								gojson/gionic.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| package json | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| // Render interface is copied from github.com/gin-gonic/gin@v1.8.1/render/render.go | ||||
| type Render interface { | ||||
| 	// Render writes data with custom ContentType. | ||||
| 	Render(http.ResponseWriter) error | ||||
| 	// WriteContentType writes custom ContentType. | ||||
| 	WriteContentType(w http.ResponseWriter) | ||||
| } | ||||
|  | ||||
| type GoJsonRender struct { | ||||
| 	Data          any | ||||
| 	NilSafeSlices bool | ||||
| 	NilSafeMaps   bool | ||||
| 	Indent        *IndentOpt | ||||
| } | ||||
|  | ||||
| func (r GoJsonRender) Render(w http.ResponseWriter) error { | ||||
| 	header := w.Header() | ||||
| 	if val := header["Content-Type"]; len(val) == 0 { | ||||
| 		header["Content-Type"] = []string{"application/json; charset=utf-8"} | ||||
| 	} | ||||
|  | ||||
| 	jsonBytes, err := MarshalSafeCollections(r.Data, r.NilSafeSlices, r.NilSafeMaps, r.Indent) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	_, err = w.Write(jsonBytes) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r GoJsonRender) WriteContentType(w http.ResponseWriter) { | ||||
| 	header := w.Header() | ||||
| 	if val := header["Content-Type"]; len(val) == 0 { | ||||
| 		header["Content-Type"] = []string{"application/json; charset=utf-8"} | ||||
| 	} | ||||
| } | ||||
| @@ -433,3 +433,10 @@ func ArrConcat[T any](arr ...[]T) []T { | ||||
| 	} | ||||
| 	return r | ||||
| } | ||||
|  | ||||
| // ArrCopy does a shallow copy of the 'in' array | ||||
| func ArrCopy[T any](in []T) []T { | ||||
| 	out := make([]T, len(in)) | ||||
| 	copy(out, in) | ||||
| 	return out | ||||
| } | ||||
|   | ||||
| @@ -61,3 +61,43 @@ func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) bool { | ||||
|  | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func CompareString(a, b string) int { | ||||
| 	if a == b { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	if a < b { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	return +1 | ||||
| } | ||||
|  | ||||
| func CompareInt(a, b int) int { | ||||
| 	if a == b { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	if a < b { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	return +1 | ||||
| } | ||||
|  | ||||
| func CompareInt64(a, b int64) int { | ||||
| 	if a == b { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	if a < b { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	return +1 | ||||
| } | ||||
|  | ||||
| func Compare[T OrderedConstraint](a, b T) int { | ||||
| 	if a == b { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	if a < b { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	return +1 | ||||
| } | ||||
|   | ||||
							
								
								
									
										71
									
								
								langext/panic.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								langext/panic.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | ||||
| package langext | ||||
|  | ||||
| type PanicWrappedErr struct { | ||||
| 	panic any | ||||
| } | ||||
|  | ||||
| func (p PanicWrappedErr) Error() string { | ||||
| 	return "A panic occured" | ||||
| } | ||||
|  | ||||
| func (p PanicWrappedErr) ReoveredObj() any { | ||||
| 	return p.panic | ||||
| } | ||||
|  | ||||
| func RunPanicSafe(fn func()) (err error) { | ||||
| 	defer func() { | ||||
| 		if rec := recover(); rec != nil { | ||||
| 			err = PanicWrappedErr{panic: rec} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	fn() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func RunPanicSafeR1(fn func() error) (err error) { | ||||
| 	defer func() { | ||||
| 		if rec := recover(); rec != nil { | ||||
| 			err = PanicWrappedErr{panic: rec} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	return fn() | ||||
| } | ||||
|  | ||||
| func RunPanicSafeR2[T1 any](fn func() (T1, error)) (r1 T1, err error) { | ||||
| 	defer func() { | ||||
| 		if rec := recover(); rec != nil { | ||||
| 			r1 = *new(T1) | ||||
| 			err = PanicWrappedErr{panic: rec} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	return fn() | ||||
| } | ||||
|  | ||||
| func RunPanicSafeR3[T1 any, T2 any](fn func() (T1, T2, error)) (r1 T1, r2 T2, err error) { | ||||
| 	defer func() { | ||||
| 		if rec := recover(); rec != nil { | ||||
| 			r1 = *new(T1) | ||||
| 			r2 = *new(T2) | ||||
| 			err = PanicWrappedErr{panic: rec} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	return fn() | ||||
| } | ||||
|  | ||||
| func RunPanicSafeR4[T1 any, T2 any, T3 any](fn func() (T1, T2, T3, error)) (r1 T1, r2 T2, r3 T3, err error) { | ||||
| 	defer func() { | ||||
| 		if rec := recover(); rec != nil { | ||||
| 			r1 = *new(T1) | ||||
| 			r2 = *new(T2) | ||||
| 			r3 = *new(T3) | ||||
| 			err = PanicWrappedErr{panic: rec} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	return fn() | ||||
| } | ||||
							
								
								
									
										111
									
								
								langext/reflection.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								langext/reflection.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,111 @@ | ||||
| package langext | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| ) | ||||
|  | ||||
| var reflectBasicTypes = []reflect.Type{ | ||||
| 	reflect.Bool:       reflect.TypeOf(false), | ||||
| 	reflect.Int:        reflect.TypeOf(int(0)), | ||||
| 	reflect.Int8:       reflect.TypeOf(int8(0)), | ||||
| 	reflect.Int16:      reflect.TypeOf(int16(0)), | ||||
| 	reflect.Int32:      reflect.TypeOf(int32(0)), | ||||
| 	reflect.Int64:      reflect.TypeOf(int64(0)), | ||||
| 	reflect.Uint:       reflect.TypeOf(uint(0)), | ||||
| 	reflect.Uint8:      reflect.TypeOf(uint8(0)), | ||||
| 	reflect.Uint16:     reflect.TypeOf(uint16(0)), | ||||
| 	reflect.Uint32:     reflect.TypeOf(uint32(0)), | ||||
| 	reflect.Uint64:     reflect.TypeOf(uint64(0)), | ||||
| 	reflect.Uintptr:    reflect.TypeOf(uintptr(0)), | ||||
| 	reflect.Float32:    reflect.TypeOf(float32(0)), | ||||
| 	reflect.Float64:    reflect.TypeOf(float64(0)), | ||||
| 	reflect.Complex64:  reflect.TypeOf(complex64(0)), | ||||
| 	reflect.Complex128: reflect.TypeOf(complex128(0)), | ||||
| 	reflect.String:     reflect.TypeOf(""), | ||||
| } | ||||
|  | ||||
| // Underlying returns the underlying type of t (without type alias) | ||||
| // | ||||
| // https://github.com/golang/go/issues/39574#issuecomment-655664772 | ||||
| func Underlying(t reflect.Type) (ret reflect.Type) { | ||||
| 	if t.Name() == "" { | ||||
| 		// t is an unnamed type. the underlying type is t itself | ||||
| 		return t | ||||
| 	} | ||||
| 	kind := t.Kind() | ||||
| 	if ret = reflectBasicTypes[kind]; ret != nil { | ||||
| 		return ret | ||||
| 	} | ||||
| 	switch kind { | ||||
| 	case reflect.Array: | ||||
| 		ret = reflect.ArrayOf(t.Len(), t.Elem()) | ||||
| 	case reflect.Chan: | ||||
| 		ret = reflect.ChanOf(t.ChanDir(), t.Elem()) | ||||
| 	case reflect.Map: | ||||
| 		ret = reflect.MapOf(t.Key(), t.Elem()) | ||||
| 	case reflect.Func: | ||||
| 		nIn := t.NumIn() | ||||
| 		nOut := t.NumOut() | ||||
| 		in := make([]reflect.Type, nIn) | ||||
| 		out := make([]reflect.Type, nOut) | ||||
| 		for i := 0; i < nIn; i++ { | ||||
| 			in[i] = t.In(i) | ||||
| 		} | ||||
| 		for i := 0; i < nOut; i++ { | ||||
| 			out[i] = t.Out(i) | ||||
| 		} | ||||
| 		ret = reflect.FuncOf(in, out, t.IsVariadic()) | ||||
| 	case reflect.Interface: | ||||
| 		// not supported | ||||
| 	case reflect.Ptr: | ||||
| 		ret = reflect.PtrTo(t.Elem()) | ||||
| 	case reflect.Slice: | ||||
| 		ret = reflect.SliceOf(t.Elem()) | ||||
| 	case reflect.Struct: | ||||
| 		// only partially supported: embedded fields | ||||
| 		// and unexported fields may cause panic in reflect.StructOf() | ||||
| 		defer func() { | ||||
| 			// if a panic happens, return t unmodified | ||||
| 			if recover() != nil && ret == nil { | ||||
| 				ret = t | ||||
| 			} | ||||
| 		}() | ||||
| 		n := t.NumField() | ||||
| 		fields := make([]reflect.StructField, n) | ||||
| 		for i := 0; i < n; i++ { | ||||
| 			fields[i] = t.Field(i) | ||||
| 		} | ||||
| 		ret = reflect.StructOf(fields) | ||||
| 	} | ||||
| 	return ret | ||||
| } | ||||
|  | ||||
| // TryCast works similar to `v2, ok := v.(T)` | ||||
| // Except it works through type alias' | ||||
| func TryCast[T any](v any) (T, bool) { | ||||
|  | ||||
| 	underlying := Underlying(reflect.TypeOf(v)) | ||||
|  | ||||
| 	def := *new(T) | ||||
|  | ||||
| 	if underlying != Underlying(reflect.TypeOf(def)) { | ||||
| 		return def, false | ||||
| 	} | ||||
|  | ||||
| 	r1 := reflect.ValueOf(v) | ||||
|  | ||||
| 	if !r1.CanConvert(underlying) { | ||||
| 		return def, false | ||||
| 	} | ||||
|  | ||||
| 	r2 := r1.Convert(underlying) | ||||
|  | ||||
| 	r3 := r2.Interface() | ||||
|  | ||||
| 	r4, ok := r3.(T) | ||||
| 	if !ok { | ||||
| 		return def, false | ||||
| 	} | ||||
|  | ||||
| 	return r4, true | ||||
| } | ||||
| @@ -22,6 +22,31 @@ func Max[T langext.OrderedConstraint](v1 T, v2 T) T { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Max3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T { | ||||
| 	result := v1 | ||||
| 	if v2 > result { | ||||
| 		result = v2 | ||||
| 	} | ||||
| 	if v3 > result { | ||||
| 		result = v3 | ||||
| 	} | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func Max4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T { | ||||
| 	result := v1 | ||||
| 	if v2 > result { | ||||
| 		result = v2 | ||||
| 	} | ||||
| 	if v3 > result { | ||||
| 		result = v3 | ||||
| 	} | ||||
| 	if v4 > result { | ||||
| 		result = v4 | ||||
| 	} | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func Min[T langext.OrderedConstraint](v1 T, v2 T) T { | ||||
| 	if v1 < v2 { | ||||
| 		return v1 | ||||
| @@ -30,6 +55,31 @@ func Min[T langext.OrderedConstraint](v1 T, v2 T) T { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Min3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T { | ||||
| 	result := v1 | ||||
| 	if v2 < result { | ||||
| 		result = v2 | ||||
| 	} | ||||
| 	if v3 < result { | ||||
| 		result = v3 | ||||
| 	} | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func Min4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T { | ||||
| 	result := v1 | ||||
| 	if v2 < result { | ||||
| 		result = v2 | ||||
| 	} | ||||
| 	if v3 < result { | ||||
| 		result = v3 | ||||
| 	} | ||||
| 	if v4 < result { | ||||
| 		result = v4 | ||||
| 	} | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func Abs[T langext.NumberConstraint](v T) T { | ||||
| 	if v < 0 { | ||||
| 		return -v | ||||
|   | ||||
							
								
								
									
										49
									
								
								mongoext/pipeline.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								mongoext/pipeline.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| package mongoext | ||||
|  | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"go.mongodb.org/mongo-driver/mongo" | ||||
| ) | ||||
|  | ||||
| // FixTextSearchPipeline moves {$match:{$text:{$search}}} entries to the front of the pipeline (otherwise its an mongo error) | ||||
| func FixTextSearchPipeline(pipeline mongo.Pipeline) mongo.Pipeline { | ||||
|  | ||||
| 	dget := func(v bson.D, k string) (bson.M, bool) { | ||||
| 		for _, e := range v { | ||||
| 			if e.Key == k { | ||||
| 				if mv, ok := e.Value.(bson.M); ok { | ||||
| 					return mv, true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return nil, false | ||||
| 	} | ||||
| 	mget := func(v bson.M, k string) (bson.M, bool) { | ||||
| 		for ekey, eval := range v { | ||||
| 			if ekey == k { | ||||
| 				if mv, ok := eval.(bson.M); ok { | ||||
| 					return mv, true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	result := make([]bson.D, 0, len(pipeline)) | ||||
|  | ||||
| 	for _, entry := range pipeline { | ||||
|  | ||||
| 		if v0, ok := dget(entry, "$match"); ok { | ||||
| 			if v1, ok := mget(v0, "$text"); ok { | ||||
| 				if _, ok := v1["$search"]; ok { | ||||
| 					result = append([]bson.D{entry}, result...) | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		result = append(result, entry) | ||||
| 	} | ||||
|  | ||||
| 	return result | ||||
| } | ||||
							
								
								
									
										30
									
								
								mongoext/projections.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								mongoext/projections.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| package mongoext | ||||
|  | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // ProjectionFromStruct automatically generated a mongodb projection for a struct | ||||
| // This way you can pretty much always write | ||||
| // `options.FindOne().SetProjection(mongoutils.ProjectionFromStruct(...your_model...))` | ||||
| // to only get the data from mongodb that you will actually use in the later decode step | ||||
| func ProjectionFromStruct(obj interface{}) bson.M { | ||||
| 	v := reflect.ValueOf(obj) | ||||
| 	t := v.Type() | ||||
|  | ||||
| 	result := bson.M{} | ||||
|  | ||||
| 	for i := 0; i < v.NumField(); i++ { | ||||
| 		tag := t.Field(i).Tag.Get("bson") | ||||
| 		if tag == "" { | ||||
| 			continue | ||||
| 		} | ||||
| 		tag = strings.Split(tag, ",")[0] | ||||
|  | ||||
| 		result[tag] = 1 | ||||
| 	} | ||||
|  | ||||
| 	return result | ||||
| } | ||||
							
								
								
									
										25
									
								
								mongoext/registry.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								mongoext/registry.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | ||||
| package mongoext | ||||
|  | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"gogs.mikescher.com/BlackForestBytes/goext/rfctime" | ||||
| 	"reflect" | ||||
| ) | ||||
|  | ||||
| func CreateGoExtBsonRegistry() *bsoncodec.Registry { | ||||
| 	rb := bsoncodec.NewRegistryBuilder() | ||||
|  | ||||
| 	rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339Time{}), rfctime.RFC3339Time{}) | ||||
| 	rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.RFC3339Time{}), rfctime.RFC3339Time{}) | ||||
|  | ||||
| 	rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{}) | ||||
| 	rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{}) | ||||
|  | ||||
| 	bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) | ||||
| 	bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb) | ||||
|  | ||||
| 	bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) | ||||
|  | ||||
| 	return rb.Build() | ||||
| } | ||||
| @@ -5,7 +5,10 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| @@ -67,6 +70,10 @@ func (t *RFC3339Time) UnmarshalText(data []byte) error { | ||||
|  | ||||
| func (t *RFC3339Time) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { | ||||
| 	if bt == bsontype.Null { | ||||
| 		// we can't set nil in UnmarshalBSONValue (so we use default(struct)) | ||||
| 		// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values | ||||
| 		// https://stackoverflow.com/questions/75167597 | ||||
| 		// https://jira.mongodb.org/browse/GODRIVER-2252 | ||||
| 		*t = RFC3339Time{} | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -86,6 +93,32 @@ func (t RFC3339Time) MarshalBSONValue() (bsontype.Type, []byte, error) { | ||||
| 	return bson.MarshalValue(time.Time(t)) | ||||
| } | ||||
|  | ||||
| func (t RFC3339Time) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if val.Kind() == reflect.Ptr && val.IsNil() { | ||||
| 		if !val.CanSet() { | ||||
| 			return errors.New("ValueUnmarshalerDecodeValue") | ||||
| 		} | ||||
| 		val.Set(reflect.New(val.Type().Elem())) | ||||
| 	} | ||||
|  | ||||
| 	tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if val.Kind() == reflect.Ptr && len(src) == 0 { | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	err = t.UnmarshalBSONValue(tp, src) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (t RFC3339Time) Serialize() string { | ||||
| 	return t.Time().Format(t.FormatStr()) | ||||
| } | ||||
|   | ||||
| @@ -5,7 +5,10 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| @@ -67,6 +70,10 @@ func (t *RFC3339NanoTime) UnmarshalText(data []byte) error { | ||||
|  | ||||
| func (t *RFC3339NanoTime) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { | ||||
| 	if bt == bsontype.Null { | ||||
| 		// we can't set nil in UnmarshalBSONValue (so we use default(struct)) | ||||
| 		// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values | ||||
| 		// https://stackoverflow.com/questions/75167597 | ||||
| 		// https://jira.mongodb.org/browse/GODRIVER-2252 | ||||
| 		*t = RFC3339NanoTime{} | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -86,6 +93,38 @@ func (t RFC3339NanoTime) MarshalBSONValue() (bsontype.Type, []byte, error) { | ||||
| 	return bson.MarshalValue(time.Time(t)) | ||||
| } | ||||
|  | ||||
| func (t RFC3339NanoTime) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if val.Kind() == reflect.Ptr && val.IsNil() { | ||||
| 		if !val.CanSet() { | ||||
| 			return errors.New("ValueUnmarshalerDecodeValue") | ||||
| 		} | ||||
| 		val.Set(reflect.New(val.Type().Elem())) | ||||
| 	} | ||||
|  | ||||
| 	tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if val.Kind() == reflect.Ptr && len(src) == 0 { | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	err = t.UnmarshalBSONValue(tp, src) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if val.Kind() == reflect.Ptr { | ||||
| 		val.Set(reflect.ValueOf(&t)) | ||||
| 	} else { | ||||
| 		val.Set(reflect.ValueOf(t)) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (t RFC3339NanoTime) Serialize() string { | ||||
| 	return t.Time().Format(t.FormatStr()) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user