diff --git a/flatten.go b/flatten.go index d7c09da..7481284 100644 --- a/flatten.go +++ b/flatten.go @@ -16,9 +16,10 @@ import ( // If you use Middle, you will probably leave Before & After blank, and vice-versa. // See examples in flatten_test.go and the "Default styles" here. type SeparatorStyle struct { - Before string // Prepend to key - Middle string // Add between keys - After string // Append to key + Before string // Prepend to key + Middle string // Add between keys + After string // Append to key + DoNotFlattenPrimitiveArrays bool // Don't flatten arrays of primitives (numbers, strings, etc.) } // Default styles @@ -86,16 +87,33 @@ func FlattenString(nestedstr, prefix string, style SeparatorStyle) (string, erro } func flatten(top bool, flatMap map[string]interface{}, nested interface{}, prefix string, style SeparatorStyle) error { + allPrimitives := func(arr []interface{}) bool { + for _, item := range arr { + switch item.(type) { + case string, int32, int64, float32, float64, json.Number: + continue + default: + return false + } + } + return true + } + assign := func(newKey string, v interface{}) error { + shouldFlatten := false + switch v.(type) { - case map[string]interface{}, []interface{}: - if err := flatten(false, flatMap, v, newKey, style); err != nil { - return err - } - default: - flatMap[newKey] = v + case []interface{}: + shouldFlatten = !(style.DoNotFlattenPrimitiveArrays && allPrimitives(v.([]interface{}))) + case map[string]interface{}: + shouldFlatten = true + } + + if shouldFlatten { + return flatten(false, flatMap, v, newKey, style) } + flatMap[newKey] = v return nil } diff --git a/flatten_test.go b/flatten_test.go index f81ebb6..865e25d 100644 --- a/flatten_test.go +++ b/flatten_test.go @@ -30,6 +30,11 @@ func TestFlatten(t *testing.T) { "d": "other", "e": "another" } + ], + "ilist": [ + 1, + 2, + 3 ] }, "number": 1.4567, @@ -43,6 +48,9 @@ func TestFlatten(t *testing.T) { "n1.alist.2": "c", "n1.alist.3.d": "other", "n1.alist.3.e": "another", + "n1.ilist.0": float64(1), + "n1.ilist.1": float64(2), + "n1.ilist.2": float64(3), "number": 1.4567, "bool": true, }, @@ -156,6 +164,76 @@ func TestFlatten(t *testing.T) { "", UnderscoreStyle, }, + { + `{ + "foo": { + "jim":"bean" + }, + "fee": "bar", + "n1": { + "alist": [ + "a", + "b", + "c" + ], + "ilist": [ + 1, + 2, + 3 + ] + }, + "number": 1.4567, + "bool": true + }`, + map[string]interface{}{ + "foo.jim": "bean", + "fee": "bar", + "n1.alist": []interface{}{ + "a", + "b", + "c", + }, + "n1.ilist": []interface{}{ + float64(1), + float64(2), + float64(3), + }, + "number": 1.4567, + "bool": true, + }, + "", + SeparatorStyle{ + Middle: ".", + DoNotFlattenPrimitiveArrays: true, + }, + }, + { + `{ + "foo": [ + 1, + 2, + 3 + ], + "fee": "bar", + "number": 1.4567, + "bool": true + }`, + map[string]interface{}{ + "foo": []interface{}{ + float64(1), + float64(2), + float64(3), + }, + "fee": "bar", + "number": 1.4567, + "bool": true, + }, + "", + SeparatorStyle{ + Middle: ".", + DoNotFlattenPrimitiveArrays: true, + }, + }, } for i, test := range cases { @@ -171,7 +249,7 @@ func TestFlatten(t *testing.T) { continue } if !reflect.DeepEqual(got, test.want) { - t.Errorf("%d: mismatch, got: %v wanted: %v", i+1, got, test.want) + t.Errorf("%d: mismatch, got: %#v wanted: %#v", i+1, got, test.want) } } } @@ -306,7 +384,7 @@ func TestFlattenString(t *testing.T) { } if got != strings.Map(nixws, test.want) { - t.Errorf("%d: mismatch, got: %v wanted: %v", i+1, got, test.want) + t.Errorf("%d: mismatch, got: %#v wanted: %#v", i+1, got, test.want) } } }