Fix some cases of overlapping $merge races

This commit is contained in:
Ian Gulliver
2023-07-27 16:48:46 -07:00
parent 1840b394b1
commit 6044e682df
8 changed files with 58 additions and 17 deletions

View File

@@ -2,8 +2,6 @@ package bkl
import (
"fmt"
"github.com/gopatchy/bkl/polyfill"
)
func merge(dst any, src any) (any, error) {
@@ -40,13 +38,13 @@ func mergeMap(dst map[string]any, src any) (map[string]any, error) {
}
func mergeMapMap(dst map[string]any, src map[string]any) (map[string]any, error) {
replace, src := popMapBoolValue(src, "$replace", true)
if replace {
replace, found := getMapBoolValue(src, "$replace")
if found && replace {
delete(src, "$replace")
return src, nil
}
dst = polyfill.MapsClone(dst)
for k, v := range src {
existing, found := dst[k]

View File

@@ -20,6 +20,10 @@ func SlicesClone[S ~[]E, E any](s S) S { //nolint:ireturn
return slices.Clone(s)
}
func SlicesDeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { //nolint:ireturn
return slices.DeleteFunc(s, del)
}
func SlicesReverse[S ~[]E, E any](s S) {
slices.Reverse(s)
}

View File

@@ -36,6 +36,28 @@ func SlicesClone[S ~[]E, E any](s S) S { //nolint:ireturn
return append(S([]E{}), s...)
}
// Copied from go1.21 slices
func SlicesDeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { //nolint:ireturn
for i, v := range s {
if del(v) {
j := i
for i++; i < len(s); i++ {
v = s[i]
if !del(v) {
s[j] = v
j++
}
}
return s[:j]
}
}
return s
}
// Copied from go1.21 slices
func SlicesReverse[S ~[]E, E any](s S) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {

View File

@@ -9,6 +9,8 @@ func Process(obj, mergeFrom any, mergeFromDocs []any) (any, error) {
return process(obj, mergeFrom, mergeFromDocs, 0)
}
// process() and descendants intentionally mutate obj to handle chained
// references
func process(obj, mergeFrom any, mergeFromDocs []any, depth int) (any, error) {
if depth > 1000 {
return nil, fmt.Errorf("%#v: %w", obj, ErrCircularRef)
@@ -30,8 +32,10 @@ func process(obj, mergeFrom any, mergeFromDocs []any, depth int) (any, error) {
}
func processMap(obj map[string]any, mergeFrom any, mergeFromDocs []any, depth int) (any, error) {
m, obj := popMapValue(obj, "$merge")
m := obj["$merge"]
if m != nil {
delete(obj, "$merge")
in, err := get(mergeFrom, mergeFromDocs, m)
if err != nil {
return nil, err
@@ -45,8 +49,10 @@ func processMap(obj map[string]any, mergeFrom any, mergeFromDocs []any, depth in
return process(next, mergeFrom, mergeFromDocs, depth)
}
m, obj = popMapValue(obj, "$replace")
m = obj["$replace"]
if m != nil {
delete(obj, "$replace")
next, err := get(mergeFrom, mergeFromDocs, m)
if err != nil {
return nil, err
@@ -55,27 +61,28 @@ func processMap(obj map[string]any, mergeFrom any, mergeFromDocs []any, depth in
return process(next, mergeFrom, mergeFromDocs, depth)
}
output, obj := popMapBoolValue(obj, "$output", false)
if output {
output, found := getMapBoolValue(obj, "$output")
if found && !output {
return nil, nil
}
encode, obj := popMapStringValue(obj, "$encode")
encode := getMapStringValue(obj, "$encode")
if encode != "" {
delete(obj, "$encode")
}
obj, err := filterMap(obj, func(k string, v any) (map[string]any, error) {
for k, v := range obj {
v2, err := process(v, mergeFrom, mergeFromDocs, depth)
if err != nil {
return nil, err
}
if v2 == nil {
return nil, nil
delete(obj, k)
continue
}
return map[string]any{k: v2}, nil
})
if err != nil {
return nil, err
obj[k] = v2
}
if encode != "" {

View File

@@ -0,0 +1 @@
$merge: c

4
tests/merge-race/a.yaml Normal file
View File

@@ -0,0 +1,4 @@
a: $required
b: $merge:a
c:
a: 1

1
tests/merge-race/cmd Normal file
View File

@@ -0,0 +1 @@
bkl -f yaml a.b.yaml

View File

@@ -0,0 +1,4 @@
a: 1
b: 1
c:
a: 1