diff --git a/cache.go b/cache.go index 5248d81..90517d3 100644 --- a/cache.go +++ b/cache.go @@ -371,5 +371,12 @@ func UnsafeLocationFilter(location string) func(*SKU) bool { } } +// IncludesFilter returns a FilterFn that checks if the SKU is included in the provided list of SKUs. +func IncludesFilter(skuList []SKU) func(*SKU) bool { + return func(s *SKU) bool { + return s.MemberOf(skuList) + } +} + // MapFn is a convenience type for mapping. type MapFn func(*SKU) SKU diff --git a/sku.go b/sku.go index 565f82f..436ad93 100644 --- a/sku.go +++ b/sku.go @@ -572,3 +572,13 @@ func (s *SKU) Equal(other *SKU) bool { localErr != nil && otherErr != nil } + +// MemberOf returns true if the SKU's name is in the list of SKUs. +func (s *SKU) MemberOf(skuList []SKU) bool { + for _, sku := range skuList { + if s.GetName() == sku.GetName() { + return true + } + } + return false +} diff --git a/sku_test.go b/sku_test.go index 7e91e2e..ec79d0a 100644 --- a/sku_test.go +++ b/sku_test.go @@ -517,3 +517,54 @@ func Test_SKU_HasCapabilityInZone(t *testing.T) { }) } } + +// Test_SKU_MemberOf tests the SKU MemberOf method +func Test_SKU_Includes(t *testing.T) { + cases := map[string]struct { + skuList []SKU + sku SKU + expect bool + }{ + "empty list should not include": { + skuList: []SKU{}, + sku: SKU{ + Name: to.StringPtr("foo"), + }, + expect: false, + }, + "missing name should not include": { + skuList: []SKU{ + { + Name: to.StringPtr("foo"), + }, + }, + sku: SKU{ + Name: to.StringPtr("bar"), + }, + expect: false, + }, + "name is included": { + skuList: []SKU{ + { + Name: to.StringPtr("foo"), + }, + { + Name: to.StringPtr("bar"), + }, + }, + sku: SKU{ + Name: to.StringPtr("bar"), + }, + expect: true, + }, + } + for name, tc := range cases { + tc := tc + t.Run(name, func(t *testing.T) { + sku := SKU(tc.sku) + if diff := cmp.Diff(tc.expect, sku.MemberOf(tc.skuList)); diff != "" { + t.Error(diff) + } + }) + } +}