Compare commits
	
		
			11 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 021465e524 | |||
| cf9c73aa4a | |||
| 0652bf22dc | |||
| b196adffc7 | |||
| 717065e62d | |||
| e7b2b040b2 | |||
| 05d0f9e469 | |||
| ccd03e50c8 | |||
| 1c77c2b8e8 | |||
| 9f6f967299 | |||
| 18c83f0f76 | 
| @@ -1274,7 +1274,7 @@ func TestMarshalSafeCollections(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		b, err := MarshalSafeCollections(tt.in, true, true) | ||||
| 		b, err := MarshalSafeCollections(tt.in, true, true, nil) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("test %d, unexpected failure: %v", i, err) | ||||
| 		} | ||||
|   | ||||
| @@ -433,3 +433,10 @@ func ArrConcat[T any](arr ...[]T) []T { | ||||
| 	} | ||||
| 	return r | ||||
| } | ||||
|  | ||||
| // ArrCopy does a shallow copy of the 'in' array | ||||
| func ArrCopy[T any](in []T) []T { | ||||
| 	out := make([]T, len(in)) | ||||
| 	copy(out, in) | ||||
| 	return out | ||||
| } | ||||
|   | ||||
| @@ -31,16 +31,16 @@ func CompareIntArr(arr1 []int, arr2 []int) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) bool { | ||||
| func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) int { | ||||
|  | ||||
| 	for i := 0; i < len(arr1) || i < len(arr2); i++ { | ||||
|  | ||||
| 		if i < len(arr1) && i < len(arr2) { | ||||
|  | ||||
| 			if arr1[i] < arr2[i] { | ||||
| 				return true | ||||
| 				return -1 | ||||
| 			} else if arr1[i] > arr2[i] { | ||||
| 				return false | ||||
| 				return +2 | ||||
| 			} else { | ||||
| 				continue | ||||
| 			} | ||||
| @@ -49,15 +49,55 @@ func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) bool { | ||||
|  | ||||
| 		if i < len(arr1) { | ||||
|  | ||||
| 			return true | ||||
| 			return +1 | ||||
|  | ||||
| 		} else { // if i < len(arr2) | ||||
|  | ||||
| 			return false | ||||
| 			return -1 | ||||
|  | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| 	return 0 | ||||
| } | ||||
|  | ||||
| func CompareString(a, b string) int { | ||||
| 	if a == b { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	if a < b { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	return +1 | ||||
| } | ||||
|  | ||||
| func CompareInt(a, b int) int { | ||||
| 	if a == b { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	if a < b { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	return +1 | ||||
| } | ||||
|  | ||||
| func CompareInt64(a, b int64) int { | ||||
| 	if a == b { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	if a < b { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	return +1 | ||||
| } | ||||
|  | ||||
| func Compare[T OrderedConstraint](a, b T) int { | ||||
| 	if a == b { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	if a < b { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	return +1 | ||||
| } | ||||
|   | ||||
							
								
								
									
										111
									
								
								langext/reflection.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								langext/reflection.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,111 @@ | ||||
| package langext | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| ) | ||||
|  | ||||
| var reflectBasicTypes = []reflect.Type{ | ||||
| 	reflect.Bool:       reflect.TypeOf(false), | ||||
| 	reflect.Int:        reflect.TypeOf(int(0)), | ||||
| 	reflect.Int8:       reflect.TypeOf(int8(0)), | ||||
| 	reflect.Int16:      reflect.TypeOf(int16(0)), | ||||
| 	reflect.Int32:      reflect.TypeOf(int32(0)), | ||||
| 	reflect.Int64:      reflect.TypeOf(int64(0)), | ||||
| 	reflect.Uint:       reflect.TypeOf(uint(0)), | ||||
| 	reflect.Uint8:      reflect.TypeOf(uint8(0)), | ||||
| 	reflect.Uint16:     reflect.TypeOf(uint16(0)), | ||||
| 	reflect.Uint32:     reflect.TypeOf(uint32(0)), | ||||
| 	reflect.Uint64:     reflect.TypeOf(uint64(0)), | ||||
| 	reflect.Uintptr:    reflect.TypeOf(uintptr(0)), | ||||
| 	reflect.Float32:    reflect.TypeOf(float32(0)), | ||||
| 	reflect.Float64:    reflect.TypeOf(float64(0)), | ||||
| 	reflect.Complex64:  reflect.TypeOf(complex64(0)), | ||||
| 	reflect.Complex128: reflect.TypeOf(complex128(0)), | ||||
| 	reflect.String:     reflect.TypeOf(""), | ||||
| } | ||||
|  | ||||
| // Underlying returns the underlying type of t (without type alias) | ||||
| // | ||||
| // https://github.com/golang/go/issues/39574#issuecomment-655664772 | ||||
| func Underlying(t reflect.Type) (ret reflect.Type) { | ||||
| 	if t.Name() == "" { | ||||
| 		// t is an unnamed type. the underlying type is t itself | ||||
| 		return t | ||||
| 	} | ||||
| 	kind := t.Kind() | ||||
| 	if ret = reflectBasicTypes[kind]; ret != nil { | ||||
| 		return ret | ||||
| 	} | ||||
| 	switch kind { | ||||
| 	case reflect.Array: | ||||
| 		ret = reflect.ArrayOf(t.Len(), t.Elem()) | ||||
| 	case reflect.Chan: | ||||
| 		ret = reflect.ChanOf(t.ChanDir(), t.Elem()) | ||||
| 	case reflect.Map: | ||||
| 		ret = reflect.MapOf(t.Key(), t.Elem()) | ||||
| 	case reflect.Func: | ||||
| 		nIn := t.NumIn() | ||||
| 		nOut := t.NumOut() | ||||
| 		in := make([]reflect.Type, nIn) | ||||
| 		out := make([]reflect.Type, nOut) | ||||
| 		for i := 0; i < nIn; i++ { | ||||
| 			in[i] = t.In(i) | ||||
| 		} | ||||
| 		for i := 0; i < nOut; i++ { | ||||
| 			out[i] = t.Out(i) | ||||
| 		} | ||||
| 		ret = reflect.FuncOf(in, out, t.IsVariadic()) | ||||
| 	case reflect.Interface: | ||||
| 		// not supported | ||||
| 	case reflect.Ptr: | ||||
| 		ret = reflect.PtrTo(t.Elem()) | ||||
| 	case reflect.Slice: | ||||
| 		ret = reflect.SliceOf(t.Elem()) | ||||
| 	case reflect.Struct: | ||||
| 		// only partially supported: embedded fields | ||||
| 		// and unexported fields may cause panic in reflect.StructOf() | ||||
| 		defer func() { | ||||
| 			// if a panic happens, return t unmodified | ||||
| 			if recover() != nil && ret == nil { | ||||
| 				ret = t | ||||
| 			} | ||||
| 		}() | ||||
| 		n := t.NumField() | ||||
| 		fields := make([]reflect.StructField, n) | ||||
| 		for i := 0; i < n; i++ { | ||||
| 			fields[i] = t.Field(i) | ||||
| 		} | ||||
| 		ret = reflect.StructOf(fields) | ||||
| 	} | ||||
| 	return ret | ||||
| } | ||||
|  | ||||
| // TryCast works similar to `v2, ok := v.(T)` | ||||
| // Except it works through type alias' | ||||
| func TryCast[T any](v any) (T, bool) { | ||||
|  | ||||
| 	underlying := Underlying(reflect.TypeOf(v)) | ||||
|  | ||||
| 	def := *new(T) | ||||
|  | ||||
| 	if underlying != Underlying(reflect.TypeOf(def)) { | ||||
| 		return def, false | ||||
| 	} | ||||
|  | ||||
| 	r1 := reflect.ValueOf(v) | ||||
|  | ||||
| 	if !r1.CanConvert(underlying) { | ||||
| 		return def, false | ||||
| 	} | ||||
|  | ||||
| 	r2 := r1.Convert(underlying) | ||||
|  | ||||
| 	r3 := r2.Interface() | ||||
|  | ||||
| 	r4, ok := r3.(T) | ||||
| 	if !ok { | ||||
| 		return def, false | ||||
| 	} | ||||
|  | ||||
| 	return r4, true | ||||
| } | ||||
| @@ -22,6 +22,31 @@ func Max[T langext.OrderedConstraint](v1 T, v2 T) T { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Max3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T { | ||||
| 	result := v1 | ||||
| 	if v2 > result { | ||||
| 		result = v2 | ||||
| 	} | ||||
| 	if v3 > result { | ||||
| 		result = v3 | ||||
| 	} | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func Max4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T { | ||||
| 	result := v1 | ||||
| 	if v2 > result { | ||||
| 		result = v2 | ||||
| 	} | ||||
| 	if v3 > result { | ||||
| 		result = v3 | ||||
| 	} | ||||
| 	if v4 > result { | ||||
| 		result = v4 | ||||
| 	} | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func Min[T langext.OrderedConstraint](v1 T, v2 T) T { | ||||
| 	if v1 < v2 { | ||||
| 		return v1 | ||||
| @@ -30,6 +55,31 @@ func Min[T langext.OrderedConstraint](v1 T, v2 T) T { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Min3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T { | ||||
| 	result := v1 | ||||
| 	if v2 < result { | ||||
| 		result = v2 | ||||
| 	} | ||||
| 	if v3 < result { | ||||
| 		result = v3 | ||||
| 	} | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func Min4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T { | ||||
| 	result := v1 | ||||
| 	if v2 < result { | ||||
| 		result = v2 | ||||
| 	} | ||||
| 	if v3 < result { | ||||
| 		result = v3 | ||||
| 	} | ||||
| 	if v4 < result { | ||||
| 		result = v4 | ||||
| 	} | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func Abs[T langext.NumberConstraint](v T) T { | ||||
| 	if v < 0 { | ||||
| 		return -v | ||||
|   | ||||
							
								
								
									
										49
									
								
								mongoext/pipeline.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								mongoext/pipeline.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| package mongoext | ||||
|  | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"go.mongodb.org/mongo-driver/mongo" | ||||
| ) | ||||
|  | ||||
| // FixTextSearchPipeline moves {$match:{$text:{$search}}} entries to the front of the pipeline (otherwise its an mongo error) | ||||
| func FixTextSearchPipeline(pipeline mongo.Pipeline) mongo.Pipeline { | ||||
|  | ||||
| 	dget := func(v bson.D, k string) (bson.M, bool) { | ||||
| 		for _, e := range v { | ||||
| 			if e.Key == k { | ||||
| 				if mv, ok := e.Value.(bson.M); ok { | ||||
| 					return mv, true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return nil, false | ||||
| 	} | ||||
| 	mget := func(v bson.M, k string) (bson.M, bool) { | ||||
| 		for ekey, eval := range v { | ||||
| 			if ekey == k { | ||||
| 				if mv, ok := eval.(bson.M); ok { | ||||
| 					return mv, true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	result := make([]bson.D, 0, len(pipeline)) | ||||
|  | ||||
| 	for _, entry := range pipeline { | ||||
|  | ||||
| 		if v0, ok := dget(entry, "$match"); ok { | ||||
| 			if v1, ok := mget(v0, "$text"); ok { | ||||
| 				if _, ok := v1["$search"]; ok { | ||||
| 					result = append([]bson.D{entry}, result...) | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		result = append(result, entry) | ||||
| 	} | ||||
|  | ||||
| 	return result | ||||
| } | ||||
							
								
								
									
										30
									
								
								mongoext/projections.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								mongoext/projections.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| package mongoext | ||||
|  | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // ProjectionFromStruct automatically generated a mongodb projection for a struct | ||||
| // This way you can pretty much always write | ||||
| // `options.FindOne().SetProjection(mongoutils.ProjectionFromStruct(...your_model...))` | ||||
| // to only get the data from mongodb that you will actually use in the later decode step | ||||
| func ProjectionFromStruct(obj interface{}) bson.M { | ||||
| 	v := reflect.ValueOf(obj) | ||||
| 	t := v.Type() | ||||
|  | ||||
| 	result := bson.M{} | ||||
|  | ||||
| 	for i := 0; i < v.NumField(); i++ { | ||||
| 		tag := t.Field(i).Tag.Get("bson") | ||||
| 		if tag == "" { | ||||
| 			continue | ||||
| 		} | ||||
| 		tag = strings.Split(tag, ",")[0] | ||||
|  | ||||
| 		result[tag] = 1 | ||||
| 	} | ||||
|  | ||||
| 	return result | ||||
| } | ||||
| @@ -3,29 +3,23 @@ package mongoext | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"gogs.mikescher.com/BlackForestBytes/goext/rfctime" | ||||
| 	"reflect" | ||||
| ) | ||||
|  | ||||
| func CreateGoExtBsonRegistry() *bsoncodec.Registry { | ||||
| 	var primitiveCodecs bson.PrimitiveCodecs | ||||
|  | ||||
| 	rb := bsoncodec.NewRegistryBuilder() | ||||
|  | ||||
| 	rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339Time{}), rfctime.RFC3339Time{}) | ||||
| 	rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.RFC3339Time{}), rfctime.RFC3339Time{}) | ||||
|  | ||||
| 	rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{}) | ||||
| 	rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{}) | ||||
|  | ||||
| 	bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) | ||||
| 	bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb) | ||||
|  | ||||
| 	rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339Time{}), rfctime.RFC3339Time{}) | ||||
|  | ||||
| 	primitiveCodecs.RegisterPrimitiveCodecs(rb) | ||||
| 	bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) | ||||
|  | ||||
| 	return rb.Build() | ||||
| } | ||||
|  | ||||
| func encodeRFC3339Time(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) { | ||||
|  | ||||
| } | ||||
|  | ||||
| func decodeRFC3339Time(ec bsoncodec.EncodeContext, vr bsonrw.ValueReader, val reflect.Value) { | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -70,7 +70,11 @@ func (t *RFC3339Time) UnmarshalText(data []byte) error { | ||||
|  | ||||
| func (t *RFC3339Time) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { | ||||
| 	if bt == bsontype.Null { | ||||
| 		//t = nil | ||||
| 		// we can't set nil in UnmarshalBSONValue (so we use default(struct)) | ||||
| 		// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values | ||||
| 		// https://stackoverflow.com/questions/75167597 | ||||
| 		// https://jira.mongodb.org/browse/GODRIVER-2252 | ||||
| 		*t = RFC3339Time{} | ||||
| 		return nil | ||||
| 	} | ||||
| 	if bt != bsontype.DateTime { | ||||
|   | ||||
| @@ -5,7 +5,10 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"go.mongodb.org/mongo-driver/bson" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| @@ -67,7 +70,11 @@ func (t *RFC3339NanoTime) UnmarshalText(data []byte) error { | ||||
|  | ||||
| func (t *RFC3339NanoTime) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { | ||||
| 	if bt == bsontype.Null { | ||||
| 		//t = nil | ||||
| 		// we can't set nil in UnmarshalBSONValue (so we use default(struct)) | ||||
| 		// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values | ||||
| 		// https://stackoverflow.com/questions/75167597 | ||||
| 		// https://jira.mongodb.org/browse/GODRIVER-2252 | ||||
| 		*t = RFC3339NanoTime{} | ||||
| 		return nil | ||||
| 	} | ||||
| 	if bt != bsontype.DateTime { | ||||
| @@ -86,6 +93,38 @@ func (t RFC3339NanoTime) MarshalBSONValue() (bsontype.Type, []byte, error) { | ||||
| 	return bson.MarshalValue(time.Time(t)) | ||||
| } | ||||
|  | ||||
| func (t RFC3339NanoTime) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if val.Kind() == reflect.Ptr && val.IsNil() { | ||||
| 		if !val.CanSet() { | ||||
| 			return errors.New("ValueUnmarshalerDecodeValue") | ||||
| 		} | ||||
| 		val.Set(reflect.New(val.Type().Elem())) | ||||
| 	} | ||||
|  | ||||
| 	tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if val.Kind() == reflect.Ptr && len(src) == 0 { | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	err = t.UnmarshalBSONValue(tp, src) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if val.Kind() == reflect.Ptr { | ||||
| 		val.Set(reflect.ValueOf(&t)) | ||||
| 	} else { | ||||
| 		val.Set(reflect.ValueOf(t)) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (t RFC3339NanoTime) Serialize() string { | ||||
| 	return t.Time().Format(t.FormatStr()) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user