diff --git a/api/api_gomux.go b/api/api_gomux.go index e016027..bd38023 100644 --- a/api/api_gomux.go +++ b/api/api_gomux.go @@ -576,7 +576,7 @@ func (api *API) getEntitiesHandler(w http.ResponseWriter, r *http.Request) { // Query data store (instrumented) rc.inst.Start("db") - entities, err := api.es.WithContext(ctx).ReadEntities(rc.entityType, q, f) + entities, err := api.es.ReadEntities(ctx, rc.entityType, q, f) rc.inst.Stop("db") if err != nil { api.readError(rc, w, err) @@ -650,7 +650,7 @@ func (api *API) postEntitiesHandler(w http.ResponseWriter, r *http.Request) { } // Write new entities to data store - ids, err = api.es.WithContext(ctx).CreateEntities(rc.wo, entities) + ids, err = api.es.CreateEntities(ctx, rc.wo, entities) rc.gm.Inc(metrics.Created, int64(len(ids))) reply: @@ -717,7 +717,7 @@ func (api *API) putEntitiesHandler(w http.ResponseWriter, r *http.Request) { } // Patch all entities matching query - entities, err = api.es.WithContext(ctx).UpdateEntities(rc.wo, q, patch) + entities, err = api.es.UpdateEntities(ctx, rc.wo, q, patch) rc.gm.Val(metrics.UpdateBulk, int64(len(entities))) rc.gm.Inc(metrics.Updated, int64(len(entities))) @@ -764,7 +764,7 @@ func (api *API) deleteEntitiesHandler(w http.ResponseWriter, r *http.Request) { } // Delete entities, returns the deleted entities - entities, err = api.es.WithContext(ctx).DeleteEntities(rc.wo, q) + entities, err = api.es.DeleteEntities(ctx, rc.wo, q) rc.gm.Val(metrics.DeleteBulk, int64(len(entities))) rc.gm.Inc(metrics.Deleted, int64(len(entities))) @@ -805,7 +805,7 @@ func (api *API) getEntityHandler(w http.ResponseWriter, r *http.Request) { // Read the entity by ID q, _ := query.Translate("_id=" + rc.entityId) - entities, err := api.es.WithContext(ctx).ReadEntities(rc.entityType, q, f) + entities, err := api.es.ReadEntities(ctx, rc.entityType, q, f) if err != nil { api.readError(rc, w, err) return @@ -838,7 +838,7 @@ func (api *API) getLabelsHandler(w http.ResponseWriter, r *http.Request) { rc.gm.Inc(metrics.ReadLabels, 1) // specific read type q, _ := query.Translate("_id=" + rc.entityId) - entities, err := api.es.WithContext(ctx).ReadEntities(rc.entityType, q, etre.QueryFilter{}) + entities, err := api.es.ReadEntities(ctx, rc.entityType, q, etre.QueryFilter{}) if err != nil { api.readError(rc, w, err) return @@ -893,7 +893,7 @@ func (api *API) postEntityHandler(w http.ResponseWriter, r *http.Request) { } // Create new entity - ids, err = api.es.WithContext(ctx).CreateEntities(rc.wo, []etre.Entity{newEntity}) + ids, err = api.es.CreateEntities(ctx, rc.wo, []etre.Entity{newEntity}) if err == nil { rc.gm.Inc(metrics.Created, 1) } @@ -947,7 +947,7 @@ func (api *API) putEntityHandler(w http.ResponseWriter, r *http.Request) { } // Patch one entity by ID - entities, err = api.es.WithContext(ctx).UpdateEntities(rc.wo, q, patch) + entities, err = api.es.UpdateEntities(ctx, rc.wo, q, patch) if err != nil { goto reply } else if len(entities) == 0 { @@ -991,7 +991,7 @@ func (api *API) deleteEntityHandler(w http.ResponseWriter, r *http.Request) { q, _ := query.Translate("_id=" + rc.entityId) // Delete one entity by ID - entities, err = api.es.WithContext(ctx).DeleteEntities(rc.wo, q) + entities, err = api.es.DeleteEntities(ctx, rc.wo, q) if err != nil { goto reply } else if len(entities) == 0 { @@ -1041,7 +1041,7 @@ func (api *API) deleteLabelHandler(w http.ResponseWriter, r *http.Request) { } // Delete label from entity - diff, err = api.es.WithContext(ctx).DeleteLabel(rc.wo, label) + diff, err = api.es.DeleteLabel(ctx, rc.wo, label) if err != nil { if err == etre.ErrEntityNotFound { err = nil // delete is idempotent diff --git a/api/api_test.go b/api/api_test.go index bea8826..b3665e1 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -198,11 +198,8 @@ func TestClientQueryTimeout(t *testing.T) { // and plumbed all the way down to the entity.Store context var gotCtx context.Context store := mock.EntityStore{} - store.WithContextFunc = func(ctx context.Context) entity.Store { + store.ReadEntitiesFunc = func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { gotCtx = ctx - return store - } - store.ReadEntitiesFunc = func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { return testEntitiesWithObjectIDs[0:1], nil } server := setup(t, defaultConfig, store) @@ -244,24 +241,25 @@ func TestContextPropagation(t *testing.T) { // Make sure context values from the request are propagated all the way down to the entity.Store context var gotCtx context.Context store := mock.EntityStore{} - store.WithContextFunc = func(ctx context.Context) entity.Store { - gotCtx = ctx - return store - } // We're going to test all operations, so we need to set all of these funcs - store.ReadEntitiesFunc = func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + store.ReadEntitiesFunc = func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + gotCtx = ctx return testEntitiesWithObjectIDs[0:1], nil } - store.CreateEntitiesFunc = func(op entity.WriteOp, entities []etre.Entity) ([]string, error) { + store.CreateEntitiesFunc = func(ctx context.Context, op entity.WriteOp, entities []etre.Entity) ([]string, error) { + gotCtx = ctx return []string{testEntityIds[0]}, nil } - store.UpdateEntitiesFunc = func(op entity.WriteOp, q query.Query, e etre.Entity) ([]etre.Entity, error) { + store.UpdateEntitiesFunc = func(ctx context.Context, op entity.WriteOp, q query.Query, e etre.Entity) ([]etre.Entity, error) { + gotCtx = ctx return testEntitiesWithObjectIDs[0:1], nil } - store.DeleteEntitiesFunc = func(op entity.WriteOp, q query.Query) ([]etre.Entity, error) { + store.DeleteEntitiesFunc = func(ctx context.Context, op entity.WriteOp, q query.Query) ([]etre.Entity, error) { + gotCtx = ctx return testEntitiesWithObjectIDs[0:1], nil } - store.DeleteLabelFunc = func(op entity.WriteOp, label string) (etre.Entity, error) { + store.DeleteLabelFunc = func(ctx context.Context, op entity.WriteOp, label string) (etre.Entity, error) { + gotCtx = ctx return testEntitiesWithObjectIDs[0], nil } diff --git a/api/bulk_write_test.go b/api/bulk_write_test.go index b338bb7..011669f 100644 --- a/api/bulk_write_test.go +++ b/api/bulk_write_test.go @@ -3,6 +3,7 @@ package api_test import ( + "context" "encoding/json" "net/http" "net/url" @@ -32,7 +33,7 @@ func TestPostEntitiesOK(t *testing.T) { var gotWO entity.WriteOp var gotEntities []etre.Entity store := mock.EntityStore{ - CreateEntitiesFunc: func(wo entity.WriteOp, entities []etre.Entity) ([]string, error) { + CreateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, entities []etre.Entity) ([]string, error) { gotWO = wo gotEntities = entities return []string{"id1", "id2"}, nil @@ -95,7 +96,7 @@ func TestPostEntitiesErrors(t *testing.T) { // Most importantly: CreateEntities() should _not_ be called. created := false store := mock.EntityStore{ - CreateEntitiesFunc: func(wo entity.WriteOp, entities []etre.Entity) ([]string, error) { + CreateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, entities []etre.Entity) ([]string, error) { created = true return []string{"id1", "id2"}, nil }, @@ -246,7 +247,7 @@ func TestPutEntitiesOK(t *testing.T) { var gotQuery query.Query var gotPatch etre.Entity store := mock.EntityStore{ - UpdateEntitiesFunc: func(wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { + UpdateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { gotWO = wo gotQuery = q gotPatch = patch @@ -330,7 +331,7 @@ func TestPutEntitiesErrors(t *testing.T) { // metrics when any input is invalid. The UpdateEntities() should not be called. updated := false store := mock.EntityStore{ - UpdateEntitiesFunc: func(wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { + UpdateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { updated = true return []etre.Entity{}, nil }, @@ -496,7 +497,7 @@ func TestDeleteEntitiesOK(t *testing.T) { var gotWO entity.WriteOp var gotQuery query.Query store := mock.EntityStore{ - DeleteEntitiesFunc: func(wo entity.WriteOp, q query.Query) ([]etre.Entity, error) { + DeleteEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, q query.Query) ([]etre.Entity, error) { gotWO = wo gotQuery = q return []etre.Entity{ @@ -575,7 +576,7 @@ func TestDeleteEntitiesErrors(t *testing.T) { // metrics when any input is invalid. The DeleteEntities() should not be called. deleted := false store := mock.EntityStore{ - DeleteEntitiesFunc: func(wo entity.WriteOp, q query.Query) ([]etre.Entity, error) { + DeleteEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, q query.Query) ([]etre.Entity, error) { deleted = true return []etre.Entity{}, nil }, diff --git a/api/query_test.go b/api/query_test.go index b18081f..c52f50e 100644 --- a/api/query_test.go +++ b/api/query_test.go @@ -28,7 +28,7 @@ func TestQueryBasic(t *testing.T) { var gotQuery query.Query var gotFilter etre.QueryFilter store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { gotQuery = q gotFilter = f return testEntitiesWithObjectIDs, nil @@ -171,7 +171,7 @@ func TestQueryNoMatches(t *testing.T) { // is still 200 OK in this case because there's no error. var gotQuery query.Query store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { gotQuery = q return []etre.Entity{}, nil // no matching queries }, @@ -220,7 +220,7 @@ func TestQueryErrorsDatabaseError(t *testing.T) { // Test that GET /entities/:type?query=Q handles a database error correctly. // Db errors (and only db errors return HTTP 503 "Service Unavailable". store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { return nil, entity.DbError{Err: fmt.Errorf("fake error"), Type: "db-read"} }, } @@ -269,7 +269,7 @@ func TestQueryErrorsNoEntityType(t *testing.T) { // You can run "../test/coverage -test.run TestQueryErrorsNoEntityType" and // see that the handler is never called. store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { return nil, entity.DbError{Err: fmt.Errorf("fake error"), Type: "db-read"} }, } @@ -310,7 +310,7 @@ func TestQueryErrorsTimeout(t *testing.T) { // Test that GET /entities/:type?query=Q handles a database timeout correctly. // Db errors (and only db errors return HTTP 503 "Service Unavailable". store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() <-ctx.Done() @@ -357,7 +357,7 @@ func TestQueryErrorsTimeout(t *testing.T) { func TestResponseCompression(t *testing.T) { // Stand up the server store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { return testEntitiesWithObjectIDs, nil }, } diff --git a/api/single_entity_read_test.go b/api/single_entity_read_test.go index 89948b8..987c192 100644 --- a/api/single_entity_read_test.go +++ b/api/single_entity_read_test.go @@ -3,6 +3,7 @@ package api_test import ( + "context" "fmt" "net/http" "net/url" @@ -31,7 +32,7 @@ func TestGetEntityBasic(t *testing.T) { var gotQuery query.Query var gotFilter etre.QueryFilter store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { gotQuery = q gotFilter = f return testEntitiesWithObjectIDs[0:1], nil @@ -83,7 +84,7 @@ func TestGetEntityReturnLabels(t *testing.T) { // that the URL param "labels=" is processed and passed along to the entity.Store. var gotFilter etre.QueryFilter store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { gotFilter = f return testEntitiesWithObjectIDs[0:1], nil }, @@ -129,7 +130,7 @@ func TestGetEntityNotFound(t *testing.T) { // We simulate this by making ReadEntities() below return an empty list which // the real entity.Store() does when no entity exists with the given _id. store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { return []etre.Entity{}, nil }, } @@ -163,7 +164,7 @@ func TestGetEntityErrors(t *testing.T) { read := false var dbError error store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { read = true return nil, dbError }, @@ -257,7 +258,7 @@ func TestGetEntityLabels(t *testing.T) { var gotQuery query.Query var gotFilter etre.QueryFilter store := mock.EntityStore{ - ReadEntitiesFunc: func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { + ReadEntitiesFunc: func(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { gotQuery = q gotFilter = f return testEntitiesWithObjectIDs[0:1], nil diff --git a/api/single_entity_write_test.go b/api/single_entity_write_test.go index eb4a4af..b54a444 100644 --- a/api/single_entity_write_test.go +++ b/api/single_entity_write_test.go @@ -3,6 +3,7 @@ package api_test import ( + "context" "encoding/json" "fmt" "net/http" @@ -33,7 +34,7 @@ func TestPostEntityOK(t *testing.T) { var gotWO entity.WriteOp var gotEntities []etre.Entity store := mock.EntityStore{ - CreateEntitiesFunc: func(wo entity.WriteOp, entities []etre.Entity) ([]string, error) { + CreateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, entities []etre.Entity) ([]string, error) { gotWO = wo gotEntities = entities return []string{"id1"}, nil @@ -96,7 +97,7 @@ func TestPostEntityDuplicate(t *testing.T) { // Test that POST /entities/:type returns HTTP 403 Conflict on duplicate // which we simulate by returning what entity.Store would: store := mock.EntityStore{ - CreateEntitiesFunc: func(wo entity.WriteOp, entities []etre.Entity) ([]string, error) { + CreateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, entities []etre.Entity) ([]string, error) { // Real CreateEntities() always returns []string, not nil, because // it supports partial writes return []string{}, entity.DbError{Err: fmt.Errorf("dupe"), Type: "duplicate-entity"} @@ -144,7 +145,7 @@ func TestPostEntityErrors(t *testing.T) { // Test that POST /entities/:type returns an error for any issue created := false store := mock.EntityStore{ - CreateEntitiesFunc: func(wo entity.WriteOp, entities []etre.Entity) ([]string, error) { + CreateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, entities []etre.Entity) ([]string, error) { created = true return []string{"id1"}, nil }, @@ -254,7 +255,7 @@ func TestPutEntityOK(t *testing.T) { var gotWO entity.WriteOp var gotQuery query.Query store := mock.EntityStore{ - UpdateEntitiesFunc: func(wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { + UpdateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { gotWO = wo gotQuery = q diff := []etre.Entity{ @@ -329,7 +330,7 @@ func TestPutEntityDuplicate(t *testing.T) { // Test that PUT /entities/:type/:id returns HTTP 403 Conflict on duplicate // which we simulate by returning what entity.Store would: store := mock.EntityStore{ - UpdateEntitiesFunc: func(wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { + UpdateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { return nil, entity.DbError{ Type: "duplicate-entity", // the key to making this happen EntityId: testEntityIds[0], @@ -380,7 +381,7 @@ func TestPutEntityNotFound(t *testing.T) { // Test that PUT /entities/:type/:id returns HTTP 404 when there's no entity // with the given id. In this case, the entity.Store returns an empty diff: store := mock.EntityStore{ - UpdateEntitiesFunc: func(wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { + UpdateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { return []etre.Entity{}, nil // no entities matched }, } @@ -419,7 +420,7 @@ func TestPutEntityErrors(t *testing.T) { // Test that PUT /entities/:type/:id returns errors unless all inputs are correct updated := false store := mock.EntityStore{ - UpdateEntitiesFunc: func(wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { + UpdateEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { updated = true return []etre.Entity{{"_id": testEntityId0, "_type": entityType, "_rev": int64(0)}}, nil }, @@ -592,7 +593,7 @@ func TestDeleteEntityOK(t *testing.T) { var gotWO entity.WriteOp var gotQuery query.Query store := mock.EntityStore{ - DeleteEntitiesFunc: func(wo entity.WriteOp, q query.Query) ([]etre.Entity, error) { + DeleteEntitiesFunc: func(ctx context.Context, wo entity.WriteOp, q query.Query) ([]etre.Entity, error) { gotWO = wo gotQuery = q diff := []etre.Entity{ @@ -671,7 +672,7 @@ func TestDeleteLabel(t *testing.T) { var gotWO entity.WriteOp var gotLabel string store := mock.EntityStore{ - DeleteLabelFunc: func(wo entity.WriteOp, label string) (etre.Entity, error) { + DeleteLabelFunc: func(ctx context.Context, wo entity.WriteOp, label string) (etre.Entity, error) { gotWO = wo gotLabel = label return etre.Entity{"_id": testEntityId0, "_type": entityType, "_rev": int64(0), "foo": "oldVal"}, nil diff --git a/entity/store.go b/entity/store.go index 73fbd80..18e80ca 100644 --- a/entity/store.go +++ b/entity/store.go @@ -19,23 +19,20 @@ import ( // Store interface has methods needed to do CRUD operations on entities. type Store interface { - WithContext(context.Context) Store + ReadEntities(context.Context, string, query.Query, etre.QueryFilter) ([]etre.Entity, error) - ReadEntities(string, query.Query, etre.QueryFilter) ([]etre.Entity, error) + CreateEntities(context.Context, WriteOp, []etre.Entity) ([]string, error) - CreateEntities(WriteOp, []etre.Entity) ([]string, error) + UpdateEntities(context.Context, WriteOp, query.Query, etre.Entity) ([]etre.Entity, error) - UpdateEntities(WriteOp, query.Query, etre.Entity) ([]etre.Entity, error) + DeleteEntities(context.Context, WriteOp, query.Query) ([]etre.Entity, error) - DeleteEntities(WriteOp, query.Query) ([]etre.Entity, error) - - DeleteLabel(WriteOp, string) (etre.Entity, error) + DeleteLabel(context.Context, WriteOp, string) (etre.Entity, error) } type store struct { coll map[string]*mongo.Collection cdcs cdc.Store - ctx context.Context config config.EntityConfig } @@ -44,20 +41,14 @@ func NewStore(entities map[string]*mongo.Collection, cdcStore cdc.Store, cfg con return store{ coll: entities, cdcs: cdcStore, - ctx: context.Background(), config: cfg, } } -func (s store) WithContext(ctx context.Context) Store { - s.ctx = ctx - return s -} - // ReadEntities queries the db and returns a slice of Entity objects if // something is found, a nil slice if nothing is found, and an error if one // occurs. -func (s store) ReadEntities(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { +func (s store) ReadEntities(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { c, ok := s.coll[entityType] if !ok { panic("invalid entity type passed to ReadEntities: " + entityType) @@ -67,9 +58,9 @@ func (s store) ReadEntities(entityType string, q query.Query, f etre.QueryFilter // "es -u node.metacluster zone=pd" returns a list of unique metacluster names. // This is 10x faster than "es node.metacluster zone=pd | sort -u". if len(f.ReturnLabels) == 1 && f.Distinct { - values, err := c.Distinct(s.ctx, f.ReturnLabels[0], Filter(q)) + values, err := c.Distinct(ctx, f.ReturnLabels[0], Filter(q)) if err != nil { - return nil, s.dbError(err, "db-read-distinct") + return nil, s.dbError(ctx, err, "db-read-distinct") } entities := make([]etre.Entity, len(values)) for i, v := range values { @@ -94,13 +85,13 @@ func (s store) ReadEntities(entityType string, q query.Query, f etre.QueryFilter // Set batch size and projection opts := options.Find().SetProjection(p).SetBatchSize(int32(s.config.BatchSize)) - cursor, err := c.Find(s.ctx, Filter(q), opts) + cursor, err := c.Find(ctx, Filter(q), opts) if err != nil { - return nil, s.dbError(err, "db-query") + return nil, s.dbError(ctx, err, "db-query") } entities := []etre.Entity{} - if err := cursor.All(s.ctx, &entities); err != nil { - return nil, s.dbError(err, "db-read-cursor") + if err := cursor.All(ctx, &entities); err != nil { + return nil, s.dbError(ctx, err, "db-read-cursor") } return entities, nil } @@ -120,7 +111,7 @@ func (s store) ReadEntities(entityType string, q query.Query, f etre.QueryFilter // entities inserted. Since the entities were inserted in order (guranteed by // inserting one by one), caller should only return subset of entities that // failed to be inserted. -func (s store) CreateEntities(wo WriteOp, entities []etre.Entity) ([]string, error) { +func (s store) CreateEntities(ctx context.Context, wo WriteOp, entities []etre.Entity) ([]string, error) { c, ok := s.coll[wo.EntityType] if !ok { panic("invalid entity type passed to CreateEntities: " + wo.EntityType) @@ -137,9 +128,9 @@ func (s store) CreateEntities(wo WriteOp, entities []etre.Entity) ([]string, err entities[i]["_created"] = now entities[i]["_updated"] = now - res, err := c.InsertOne(s.ctx, entities[i]) + res, err := c.InsertOne(ctx, entities[i]) if err != nil { - return newIds, s.dbError(err, "db-insert") + return newIds, s.dbError(ctx, err, "db-insert") } id := res.InsertedID.(primitive.ObjectID) newIds = append(newIds, id.Hex()) @@ -152,7 +143,7 @@ func (s store) CreateEntities(wo WriteOp, entities []etre.Entity) ([]string, err old: nil, rev: int64(0), } - if err := s.cdcWrite(entities[i], wo, cp); err != nil { + if err := s.cdcWrite(ctx, entities[i], wo, cp); err != nil { return newIds, err } } @@ -174,18 +165,18 @@ func (s store) CreateEntities(wo WriteOp, entities []etre.Entity) ([]string, err // update := db.Entity{"y": "bar"} // // diffs, err := c.UpdateEntities(q, update) -func (s store) UpdateEntities(wo WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { +func (s store) UpdateEntities(ctx context.Context, wo WriteOp, q query.Query, patch etre.Entity) ([]etre.Entity, error) { c, ok := s.coll[wo.EntityType] if !ok { panic("invalid entity type passed to UpdateEntities: " + wo.EntityType) } fopts := options.Find().SetProjection(bson.M{"_id": 1}) - cursor, err := c.Find(s.ctx, Filter(q), fopts) + cursor, err := c.Find(ctx, Filter(q), fopts) if err != nil { - return nil, s.dbError(err, "db-query") + return nil, s.dbError(ctx, err, "db-query") } - defer cursor.Close(s.ctx) + defer cursor.Close(ctx) // diffs is a slice made up of a diff for each doc updated diffs := []etre.Entity{} @@ -205,19 +196,19 @@ func (s store) UpdateEntities(wo WriteOp, q query.Query, patch etre.Entity) ([]e opts := options.FindOneAndUpdate().SetProjection(p) nextId := map[string]primitive.ObjectID{} - for cursor.Next(s.ctx) { + for cursor.Next(ctx) { if err := cursor.Decode(&nextId); err != nil { - return diffs, s.dbError(err, "db-cursor-decode") + return diffs, s.dbError(ctx, err, "db-cursor-decode") } uq, _ := query.Translate("_id=" + nextId["_id"].Hex()) var orig etre.Entity - err := c.FindOneAndUpdate(s.ctx, Filter(uq), updates, opts).Decode(&orig) + err := c.FindOneAndUpdate(ctx, Filter(uq), updates, opts).Decode(&orig) if err != nil { if err == mongo.ErrNoDocuments { break } - return diffs, s.dbError(err, "db-update") + return diffs, s.dbError(ctx, err, "db-update") } diffs = append(diffs, orig) @@ -236,13 +227,13 @@ func (s store) UpdateEntities(wo WriteOp, q query.Query, patch etre.Entity) ([]e old: &old, new: &patch, } - if err := s.cdcWrite(patch, wo, cp); err != nil { + if err := s.cdcWrite(ctx, patch, wo, cp); err != nil { return diffs, err } } if err := cursor.Err(); err != nil { - return diffs, s.dbError(err, "db-cursor-next") + return diffs, s.dbError(ctx, err, "db-cursor-next") } return diffs, nil @@ -256,7 +247,7 @@ func (s store) UpdateEntities(wo WriteOp, q query.Query, patch etre.Entity) ([]e // Returns a slice of successfully deleted entities an error if there is one. // For example, if 4 entities were supposed to be deleted and 3 are ok and the // 4th fails, a slice with 3 deleted entities and an error will be returned. -func (s store) DeleteEntities(wo WriteOp, q query.Query) ([]etre.Entity, error) { +func (s store) DeleteEntities(ctx context.Context, wo WriteOp, q query.Query) ([]etre.Entity, error) { c, ok := s.coll[wo.EntityType] if !ok { panic("invalid entity type passed to DeleteEntities: " + wo.EntityType) @@ -265,12 +256,12 @@ func (s store) DeleteEntities(wo WriteOp, q query.Query) ([]etre.Entity, error) deleted := []etre.Entity{} for { var old etre.Entity - err := c.FindOneAndDelete(s.ctx, Filter(q)).Decode(&old) + err := c.FindOneAndDelete(ctx, Filter(q)).Decode(&old) if err != nil { if err == mongo.ErrNoDocuments { break } - return deleted, s.dbError(err, "db-delete") + return deleted, s.dbError(ctx, err, "db-delete") } deleted = append(deleted, old) ce := cdcPartial{ @@ -280,7 +271,7 @@ func (s store) DeleteEntities(wo WriteOp, q query.Query) ([]etre.Entity, error) new: nil, rev: old.Rev() + 1, } - if err := s.cdcWrite(old, wo, ce); err != nil { + if err := s.cdcWrite(ctx, old, wo, ce); err != nil { return deleted, err } } @@ -289,7 +280,7 @@ func (s store) DeleteEntities(wo WriteOp, q query.Query) ([]etre.Entity, error) } // DeleteLabel deletes a label from an entity. -func (s store) DeleteLabel(wo WriteOp, label string) (etre.Entity, error) { +func (s store) DeleteLabel(ctx context.Context, wo WriteOp, label string) (etre.Entity, error) { c, ok := s.coll[wo.EntityType] if !ok { panic("invalid entity type passed to DeleteLabel: " + wo.EntityType) @@ -305,9 +296,9 @@ func (s store) DeleteLabel(wo WriteOp, label string) (etre.Entity, error) { SetProjection(bson.M{"_id": 1, "_type": 1, "_rev": 1, label: 1}). SetReturnDocument(options.Before) var old etre.Entity - err := c.FindOneAndUpdate(s.ctx, filter, update, opts).Decode(&old) + err := c.FindOneAndUpdate(ctx, filter, update, opts).Decode(&old) if err != nil { - return nil, s.dbError(err, "db-update") + return nil, s.dbError(ctx, err, "db-update") } // Make the new Entity by copying the old and deleting the label @@ -336,15 +327,15 @@ func (s store) DeleteLabel(wo WriteOp, label string) (etre.Entity, error) { old: &old, rev: old.Rev() + 1, } - if err := s.cdcWrite(etre.Entity{}, wo, cp); err != nil { + if err := s.cdcWrite(ctx, etre.Entity{}, wo, cp); err != nil { return old, err } return old, nil } -func (s store) dbError(err error, errType string) error { - if ctxErr := s.ctx.Err(); ctxErr != nil { +func (s store) dbError(ctx context.Context, err error, errType string) error { + if ctxErr := ctx.Err(); ctxErr != nil { return DbError{Err: ctxErr, Type: errType} } if dupe := IsDupeKeyError(err); dupe != nil { @@ -370,7 +361,7 @@ type cdcPartial struct { rev int64 } -func (s store) cdcWrite(e etre.Entity, wo WriteOp, cp cdcPartial) error { +func (s store) cdcWrite(ctx context.Context, e etre.Entity, wo WriteOp, cp cdcPartial) error { // set op from entity or wo, in that order. set := e.Set() if set.Size == 0 && wo.SetSize > 0 { @@ -393,7 +384,7 @@ func (s store) cdcWrite(e etre.Entity, wo WriteOp, cp cdcPartial) error { SetOp: set.Op, SetSize: set.Size, } - if err := s.cdcs.Write(s.ctx, event); err != nil { + if err := s.cdcs.Write(ctx, event); err != nil { return DbError{Err: err, Type: "cdc-write", EntityId: cp.id.Hex()} } return nil diff --git a/entity/store_test.go b/entity/store_test.go index a2cf7d9..1b1a0a4 100644 --- a/entity/store_test.go +++ b/entity/store_test.go @@ -116,7 +116,7 @@ func TestReadEntitiesWithAllOperators(t *testing.T) { q, err := query.Translate(qs) require.NoError(t, err) - actual, err := store.ReadEntities(entityType, q, etre.QueryFilter{}) + actual, err := store.ReadEntities(context.Background(), entityType, q, etre.QueryFilter{}) require.NoError(t, err) assert.Equal(t, expect, actual) } @@ -158,7 +158,7 @@ func TestReadEntitiesMatching(t *testing.T) { q, err := query.Translate(rt.query) require.NoError(t, err) - got, err := store.ReadEntities(entityType, q, etre.QueryFilter{}) + got, err := store.ReadEntities(context.Background(), entityType, q, etre.QueryFilter{}) require.NoError(t, err) assert.Equal(t, rt.expect, got) } @@ -176,7 +176,7 @@ func TestReadEntitiesFilterDistinct(t *testing.T) { ReturnLabels: []string{"y"}, // only works with 1 return label Distinct: true, } - got, err := store.ReadEntities(entityType, q, f) + got, err := store.ReadEntities(context.Background(), entityType, q, f) require.NoError(t, err) expect := []etre.Entity{ @@ -202,7 +202,7 @@ func TestReadEntitiesFilterReturnLabels(t *testing.T) { {"x": int64(4)}, {"x": int64(6)}, } - got, err := store.ReadEntities(entityType, q, f) + got, err := store.ReadEntities(context.Background(), entityType, q, f) require.NoError(t, err) assert.Equal(t, expect, got) } @@ -212,7 +212,7 @@ func TestReadEntitiesFilterReturnMetalabels(t *testing.T) { q, err := query.Translate("y=a") require.NoError(t, err) - actual, err := store.ReadEntities(entityType, q, etre.QueryFilter{ReturnLabels: []string{"_id", "_type", "_rev", "y", "_created", "_updated"}}) + actual, err := store.ReadEntities(context.Background(), entityType, q, etre.QueryFilter{ReturnLabels: []string{"_id", "_type", "_rev", "y", "_created", "_updated"}}) require.NoError(t, err) expect := []etre.Entity{ @@ -243,7 +243,7 @@ func TestCreateEntitiesMultiple(t *testing.T) { etre.Entity{"x": 8}, etre.Entity{"x": 9, "_setId": "343", "_setOp": "something", "_setSize": 1}, } - ids, err := store.CreateEntities(wo, testData) + ids, err := store.CreateEntities(context.Background(), wo, testData) require.NoError(t, err) assert.Len(t, ids, len(testData)) @@ -314,7 +314,7 @@ func TestCreateEntitiesMultiplePartialSuccess(t *testing.T) { etre.Entity{"x": 6}, // dupe etre.Entity{"x": 7}, // would be ok but blocked by dupe } - ids, err := store.CreateEntities(wo, testData) + ids, err := store.CreateEntities(context.Background(), wo, testData) require.Error(t, err) dberr, ok := err.(entity.DbError) require.True(t, ok, "got error type %#v, expected entity.DbError", err) @@ -371,7 +371,7 @@ func TestUpdateEntities(t *testing.T) { SetId: "111", SetSize: 1, } - gotDiffs, err := store.UpdateEntities(wo1, q, patch) + gotDiffs, err := store.UpdateEntities(context.Background(), wo1, q, patch) require.NoError(t, err) assert.Len(t, gotDiffs, 1) expectDiffs := []etre.Entity{ @@ -398,7 +398,7 @@ func TestUpdateEntities(t *testing.T) { SetId: "222", SetSize: 1, } - gotDiffs, err = store.UpdateEntities(wo2, q, patch) + gotDiffs, err = store.UpdateEntities(context.Background(), wo2, q, patch) require.NoError(t, err) assert.Len(t, gotDiffs, 2) expectDiffs = []etre.Entity{ @@ -501,7 +501,7 @@ func TestUpdateEntitiesById(t *testing.T) { SetId: "111", SetSize: 1, } - gotDiffs, err := store.UpdateEntities(wo1, q, patch) + gotDiffs, err := store.UpdateEntities(context.Background(), wo1, q, patch) require.NoError(t, err) expectDiffs := []etre.Entity{ { @@ -527,7 +527,7 @@ func TestUpdateEntitiesById(t *testing.T) { SetId: "222", SetSize: 1, } - gotDiffs, err = store.UpdateEntities(wo2, q, patch) + gotDiffs, err = store.UpdateEntities(context.Background(), wo2, q, patch) require.NoError(t, err) expectDiffs = []etre.Entity{ { @@ -623,7 +623,7 @@ func TestUpdateEntitiesDuplicate(t *testing.T) { EntityType: entityType, Caller: username, } - gotDiffs, err := store.UpdateEntities(wo1, q, patch) + gotDiffs, err := store.UpdateEntities(context.Background(), wo1, q, patch) require.Error(t, err) dberr, ok := err.(entity.DbError) require.True(t, ok, "got error type %#v, expected entity.DbError", err) @@ -650,7 +650,7 @@ func TestDeleteEntities(t *testing.T) { q, err := query.Translate("y == a") require.NoError(t, err) - gotOld, err := store.DeleteEntities(wo, q) + gotOld, err := store.DeleteEntities(context.Background(), wo, q) require.NoError(t, err) assert.Equal(t, testNodes[:1], gotOld) @@ -658,7 +658,7 @@ func TestDeleteEntities(t *testing.T) { q, err = query.Translate("y == b") require.NoError(t, err) - gotOld, err = store.DeleteEntities(wo, q) + gotOld, err = store.DeleteEntities(context.Background(), wo, q) require.NoError(t, err) assert.Equal(t, testNodes[1:], gotOld) @@ -717,7 +717,7 @@ func TestDeleteLabel(t *testing.T) { EntityId: testNodes[0]["_id"].(primitive.ObjectID).Hex(), Caller: username, } - gotOld, err := store.DeleteLabel(wo, "foo") + gotOld, err := store.DeleteLabel(context.Background(), wo, "foo") require.NoError(t, err) expectOld := etre.Entity{ @@ -730,7 +730,7 @@ func TestDeleteLabel(t *testing.T) { // The foo label should no longer be set on the entity q, _ := query.Translate("y=a") - gotNew, err := store.ReadEntities(entityType, q, etre.QueryFilter{}) + gotNew, err := store.ReadEntities(context.Background(), entityType, q, etre.QueryFilter{}) require.NoError(t, err) e := etre.Entity{} diff --git a/entity/v09_test.go b/entity/v09_test.go index 644f8a9..76e1f3f 100644 --- a/entity/v09_test.go +++ b/entity/v09_test.go @@ -94,7 +94,7 @@ func TestV09CreateEntitiesMultiple(t *testing.T) { {"x": "e"}, {"x": "f", "_setId": "343", "_setOp": "something", "_setSize": 1}, } - ids, err := store.CreateEntities(wo, testData) + ids, err := store.CreateEntities(context.Background(), wo, testData) require.NoError(t, err) assert.Len(t, ids, len(testData)) @@ -167,7 +167,7 @@ func TestV09UpdateEntities(t *testing.T) { SetId: "111", SetSize: 1, } - gotDiffs, err := store.UpdateEntities(wo1, q, patch) + gotDiffs, err := store.UpdateEntities(context.Background(), wo1, q, patch) require.NoError(t, err) expectDiffs := []etre.Entity{ { @@ -218,7 +218,7 @@ func TestV09DeleteEntities(t *testing.T) { q, err := query.Translate("x == a") require.NoError(t, err) - gotOld, err := store.DeleteEntities(wo, q) + gotOld, err := store.DeleteEntities(context.Background(), wo, q) require.NoError(t, err) assert.Equal(t, v09testNodes_int32[:1], gotOld) @@ -226,7 +226,7 @@ func TestV09DeleteEntities(t *testing.T) { q, err = query.Translate("x in (b,c)") require.NoError(t, err) - gotOld, err = store.DeleteEntities(wo, q) + gotOld, err = store.DeleteEntities(context.Background(), wo, q) require.NoError(t, err) assert.Equal(t, v09testNodes_int32[1:], gotOld) @@ -281,7 +281,7 @@ func TestV09DeleteLabel(t *testing.T) { EntityId: v09testNodes[0]["_id"].(primitive.ObjectID).Hex(), Caller: username, } - gotOld, err := store.DeleteLabel(wo, "y") + gotOld, err := store.DeleteLabel(context.Background(), wo, "y") require.NoError(t, err) expectOld := etre.Entity{ @@ -294,7 +294,7 @@ func TestV09DeleteLabel(t *testing.T) { // The foo label should no longer be set on the entity q, _ := query.Translate("x=a") - gotNew, err := store.ReadEntities(entityType, q, etre.QueryFilter{}) + gotNew, err := store.ReadEntities(context.Background(), entityType, q, etre.QueryFilter{}) require.NoError(t, err) e := etre.Entity{} diff --git a/test/mock/entity.go b/test/mock/entity.go index dfe666c..edf39bd 100644 --- a/test/mock/entity.go +++ b/test/mock/entity.go @@ -11,60 +11,52 @@ import ( ) type EntityStore struct { - WithContextFunc func(context.Context) entity.Store - ReadEntitiesFunc func(string, query.Query, etre.QueryFilter) ([]etre.Entity, error) - DeleteEntityLabelFunc func(entity.WriteOp, string) (etre.Entity, error) - CreateEntitiesFunc func(entity.WriteOp, []etre.Entity) ([]string, error) - UpdateEntitiesFunc func(entity.WriteOp, query.Query, etre.Entity) ([]etre.Entity, error) - DeleteEntitiesFunc func(entity.WriteOp, query.Query) ([]etre.Entity, error) - DeleteLabelFunc func(entity.WriteOp, string) (etre.Entity, error) + ReadEntitiesFunc func(context.Context, string, query.Query, etre.QueryFilter) ([]etre.Entity, error) + DeleteEntityLabelFunc func(context.Context, entity.WriteOp, string) (etre.Entity, error) + CreateEntitiesFunc func(context.Context, entity.WriteOp, []etre.Entity) ([]string, error) + UpdateEntitiesFunc func(context.Context, entity.WriteOp, query.Query, etre.Entity) ([]etre.Entity, error) + DeleteEntitiesFunc func(context.Context, entity.WriteOp, query.Query) ([]etre.Entity, error) + DeleteLabelFunc func(context.Context, entity.WriteOp, string) (etre.Entity, error) } -func (s EntityStore) WithContext(ctx context.Context) entity.Store { - if s.WithContextFunc != nil { - return s.WithContextFunc(ctx) - } - return s -} - -func (s EntityStore) DeleteEntityLabel(wo entity.WriteOp, label string) (etre.Entity, error) { +func (s EntityStore) DeleteEntityLabel(ctx context.Context, wo entity.WriteOp, label string) (etre.Entity, error) { if s.DeleteEntityLabelFunc != nil { - return s.DeleteEntityLabelFunc(wo, label) + return s.DeleteEntityLabelFunc(ctx, wo, label) } return nil, nil } -func (s EntityStore) CreateEntities(wo entity.WriteOp, entities []etre.Entity) ([]string, error) { +func (s EntityStore) CreateEntities(ctx context.Context, wo entity.WriteOp, entities []etre.Entity) ([]string, error) { if s.CreateEntitiesFunc != nil { - return s.CreateEntitiesFunc(wo, entities) + return s.CreateEntitiesFunc(ctx, wo, entities) } return nil, nil } -func (s EntityStore) ReadEntities(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { +func (s EntityStore) ReadEntities(ctx context.Context, entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { if s.ReadEntitiesFunc != nil { - return s.ReadEntitiesFunc(entityType, q, f) + return s.ReadEntitiesFunc(ctx, entityType, q, f) } return nil, nil } -func (s EntityStore) UpdateEntities(wo entity.WriteOp, q query.Query, u etre.Entity) ([]etre.Entity, error) { +func (s EntityStore) UpdateEntities(ctx context.Context, wo entity.WriteOp, q query.Query, u etre.Entity) ([]etre.Entity, error) { if s.UpdateEntitiesFunc != nil { - return s.UpdateEntitiesFunc(wo, q, u) + return s.UpdateEntitiesFunc(ctx, wo, q, u) } return nil, nil } -func (s EntityStore) DeleteEntities(wo entity.WriteOp, q query.Query) ([]etre.Entity, error) { +func (s EntityStore) DeleteEntities(ctx context.Context, wo entity.WriteOp, q query.Query) ([]etre.Entity, error) { if s.DeleteEntitiesFunc != nil { - return s.DeleteEntitiesFunc(wo, q) + return s.DeleteEntitiesFunc(ctx, wo, q) } return nil, nil } -func (s EntityStore) DeleteLabel(wo entity.WriteOp, label string) (etre.Entity, error) { +func (s EntityStore) DeleteLabel(ctx context.Context, wo entity.WriteOp, label string) (etre.Entity, error) { if s.DeleteLabelFunc != nil { - return s.DeleteLabelFunc(wo, label) + return s.DeleteLabelFunc(ctx, wo, label) } return etre.Entity{}, nil }