Fix some cases of overlapping $merge races

This commit is contained in:
Ian Gulliver
2023-07-27 16:48:46 -07:00
parent 1840b394b1
commit 6044e682df
8 changed files with 58 additions and 17 deletions

View File

@@ -2,8 +2,6 @@ package bkl
import ( import (
"fmt" "fmt"
"github.com/gopatchy/bkl/polyfill"
) )
func merge(dst any, src any) (any, error) { func merge(dst any, src any) (any, error) {
@@ -40,13 +38,13 @@ func mergeMap(dst map[string]any, src any) (map[string]any, error) {
} }
func mergeMapMap(dst map[string]any, src map[string]any) (map[string]any, error) { func mergeMapMap(dst map[string]any, src map[string]any) (map[string]any, error) {
replace, src := popMapBoolValue(src, "$replace", true) replace, found := getMapBoolValue(src, "$replace")
if replace { if found && replace {
delete(src, "$replace")
return src, nil return src, nil
} }
dst = polyfill.MapsClone(dst)
for k, v := range src { for k, v := range src {
existing, found := dst[k] existing, found := dst[k]

View File

@@ -20,6 +20,10 @@ func SlicesClone[S ~[]E, E any](s S) S { //nolint:ireturn
return slices.Clone(s) return slices.Clone(s)
} }
func SlicesDeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { //nolint:ireturn
return slices.DeleteFunc(s, del)
}
func SlicesReverse[S ~[]E, E any](s S) { func SlicesReverse[S ~[]E, E any](s S) {
slices.Reverse(s) slices.Reverse(s)
} }

View File

@@ -36,6 +36,28 @@ func SlicesClone[S ~[]E, E any](s S) S { //nolint:ireturn
return append(S([]E{}), s...) return append(S([]E{}), s...)
} }
// Copied from go1.21 slices
func SlicesDeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { //nolint:ireturn
for i, v := range s {
if del(v) {
j := i
for i++; i < len(s); i++ {
v = s[i]
if !del(v) {
s[j] = v
j++
}
}
return s[:j]
}
}
return s
}
// Copied from go1.21 slices // Copied from go1.21 slices
func SlicesReverse[S ~[]E, E any](s S) { func SlicesReverse[S ~[]E, E any](s S) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {

View File

@@ -9,6 +9,8 @@ func Process(obj, mergeFrom any, mergeFromDocs []any) (any, error) {
return process(obj, mergeFrom, mergeFromDocs, 0) return process(obj, mergeFrom, mergeFromDocs, 0)
} }
// process() and descendants intentionally mutate obj to handle chained
// references
func process(obj, mergeFrom any, mergeFromDocs []any, depth int) (any, error) { func process(obj, mergeFrom any, mergeFromDocs []any, depth int) (any, error) {
if depth > 1000 { if depth > 1000 {
return nil, fmt.Errorf("%#v: %w", obj, ErrCircularRef) return nil, fmt.Errorf("%#v: %w", obj, ErrCircularRef)
@@ -30,8 +32,10 @@ func process(obj, mergeFrom any, mergeFromDocs []any, depth int) (any, error) {
} }
func processMap(obj map[string]any, mergeFrom any, mergeFromDocs []any, depth int) (any, error) { func processMap(obj map[string]any, mergeFrom any, mergeFromDocs []any, depth int) (any, error) {
m, obj := popMapValue(obj, "$merge") m := obj["$merge"]
if m != nil { if m != nil {
delete(obj, "$merge")
in, err := get(mergeFrom, mergeFromDocs, m) in, err := get(mergeFrom, mergeFromDocs, m)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -45,8 +49,10 @@ func processMap(obj map[string]any, mergeFrom any, mergeFromDocs []any, depth in
return process(next, mergeFrom, mergeFromDocs, depth) return process(next, mergeFrom, mergeFromDocs, depth)
} }
m, obj = popMapValue(obj, "$replace") m = obj["$replace"]
if m != nil { if m != nil {
delete(obj, "$replace")
next, err := get(mergeFrom, mergeFromDocs, m) next, err := get(mergeFrom, mergeFromDocs, m)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -55,27 +61,28 @@ func processMap(obj map[string]any, mergeFrom any, mergeFromDocs []any, depth in
return process(next, mergeFrom, mergeFromDocs, depth) return process(next, mergeFrom, mergeFromDocs, depth)
} }
output, obj := popMapBoolValue(obj, "$output", false) output, found := getMapBoolValue(obj, "$output")
if output { if found && !output {
return nil, nil return nil, nil
} }
encode, obj := popMapStringValue(obj, "$encode") encode := getMapStringValue(obj, "$encode")
if encode != "" {
delete(obj, "$encode")
}
obj, err := filterMap(obj, func(k string, v any) (map[string]any, error) { for k, v := range obj {
v2, err := process(v, mergeFrom, mergeFromDocs, depth) v2, err := process(v, mergeFrom, mergeFromDocs, depth)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if v2 == nil { if v2 == nil {
return nil, nil delete(obj, k)
continue
} }
return map[string]any{k: v2}, nil obj[k] = v2
})
if err != nil {
return nil, err
} }
if encode != "" { if encode != "" {

View File

@@ -0,0 +1 @@
$merge: c

4
tests/merge-race/a.yaml Normal file
View File

@@ -0,0 +1,4 @@
a: $required
b: $merge:a
c:
a: 1

1
tests/merge-race/cmd Normal file
View File

@@ -0,0 +1 @@
bkl -f yaml a.b.yaml

View File

@@ -0,0 +1,4 @@
a: 1
b: 1
c:
a: 1