diff --git a/pkg/cmd/apply/apply.go b/pkg/cmd/apply/apply.go index c8502ca..066a63c 100644 --- a/pkg/cmd/apply/apply.go +++ b/pkg/cmd/apply/apply.go @@ -1,9 +1,7 @@ package apply import ( - "bytes" "context" - "encoding/json" "fmt" "io" "os" @@ -13,11 +11,8 @@ import ( "github.com/schollz/progressbar/v3" "github.com/spf13/cobra" "google.golang.org/grpc" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protoregistry" - "github.com/Azure/AKSFlexNode/components/api" "github.com/Azure/AKSFlexNode/components/services/actions" "github.com/Azure/AKSFlexNode/components/services/inmem" ) @@ -44,7 +39,11 @@ var Command = &cobra.Command{ return err } - return apply(cmd.Context(), input, flagNoPrettyUI) + parsed, err := parseActions(input) + if err != nil { + return err + } + return apply(cmd.Context(), parsed, flagNoPrettyUI) }, SilenceUsage: true, } @@ -76,37 +75,13 @@ type stepResult struct { err error } -func apply(ctx context.Context, input []byte, noPrettyUI bool) error { +func apply(ctx context.Context, parsed []parsedAction, noPrettyUI bool) error { conn, err := inmem.NewConnection() if err != nil { return err } defer conn.Close() //nolint:errcheck // close connection - tok, err := json.NewDecoder(bytes.NewBuffer(input)).Token() - if err != nil { - return err - } - - var bs []json.RawMessage - if tok == json.Delim('[') { - if err := json.Unmarshal(input, &bs); err != nil { - return err - } - } else { - bs = append(bs, input) - } - - // Pre-parse all actions so we know the total count and names up front. - parsed := make([]parsedAction, 0, len(bs)) - for _, b := range bs { - pa, err := parseAction(b) - if err != nil { - return err - } - parsed = append(parsed, pa) - } - if noPrettyUI { return applyPlain(ctx, conn, parsed) } @@ -243,31 +218,3 @@ func formatDuration(d time.Duration) string { return fmt.Sprintf("%.1fs", d.Seconds()) } } - -func parseAction(b []byte) (parsedAction, error) { - base := &api.Base{} - if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(b, base); err != nil { - return parsedAction{}, err - } - - actionType := base.GetMetadata().GetType() - actionName := base.GetMetadata().GetName() - - mt, err := protoregistry.GlobalTypes.FindMessageByURL(actionType) - if err != nil { - return parsedAction{}, fmt.Errorf("lookup action type %q: %w", actionType, err) - } - - m := mt.New().Interface() - if err := protojson.Unmarshal(b, m); err != nil { - return parsedAction{}, fmt.Errorf("unmarshal action %q: %w", actionType, err) - } - - // Use the action name if available, otherwise fall back to the type URL. - name := actionName - if name == "" { - name = actionType - } - - return parsedAction{name: name, message: m}, nil -} diff --git a/pkg/cmd/apply/parse.go b/pkg/cmd/apply/parse.go new file mode 100644 index 0000000..dcb961d --- /dev/null +++ b/pkg/cmd/apply/parse.go @@ -0,0 +1,125 @@ +package apply + +import ( + "bytes" + "encoding/json" + "fmt" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoregistry" + + "github.com/Azure/AKSFlexNode/components/api" +) + +// isJSONContent reports whether input appears to be a JSON object or array. +// It inspects the first non-whitespace byte: JSON objects begin with '{' and +// JSON arrays begin with '['; binary protobuf never starts with either byte. +func isJSONContent(input []byte) bool { + for _, b := range input { + if b == ' ' || b == '\t' || b == '\r' || b == '\n' { + continue + } + return b == '{' || b == '[' + } + return false +} + +// parseActions detects the input format and returns the pre-parsed actions. +func parseActions(input []byte) ([]parsedAction, error) { + if isJSONContent(input) { + return parseActionFromJSON(input) + } + + pa, err := parseActionFromProto(input) + if err != nil { + return nil, err + } + return []parsedAction{pa}, nil +} + +// parseActionFromJSON deserializes one or more JSON-encoded actions. The input +// may be a single JSON object or a JSON array of objects. +func parseActionFromJSON(input []byte) ([]parsedAction, error) { + tok, err := json.NewDecoder(bytes.NewBuffer(input)).Token() + if err != nil { + return nil, err + } + + var bs []json.RawMessage + if tok == json.Delim('[') { + if err := json.Unmarshal(input, &bs); err != nil { + return nil, err + } + } else { + bs = append(bs, input) + } + + // Pre-parse all actions so we know the total count and names up front. + parsed := make([]parsedAction, 0, len(bs)) + for _, b := range bs { + pa, err := parseAction(b) + if err != nil { + return nil, err + } + parsed = append(parsed, pa) + } + return parsed, nil +} + +func parseAction(b []byte) (parsedAction, error) { + base := &api.Base{} + if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(b, base); err != nil { + return parsedAction{}, err + } + + actionType := base.GetMetadata().GetType() + actionName := base.GetMetadata().GetName() + + mt, err := protoregistry.GlobalTypes.FindMessageByURL(actionType) + if err != nil { + return parsedAction{}, fmt.Errorf("lookup action type %q: %w", actionType, err) + } + + m := mt.New().Interface() + if err := protojson.Unmarshal(b, m); err != nil { + return parsedAction{}, fmt.Errorf("unmarshal action %q: %w", actionType, err) + } + + // Use the action name if available, otherwise fall back to the type URL. + name := actionName + if name == "" { + name = actionType + } + + return parsedAction{name: name, message: m}, nil +} + +// parseActionFromProto deserializes a single binary protobuf-encoded action. +func parseActionFromProto(b []byte) (parsedAction, error) { + base := &api.Base{} + if err := proto.Unmarshal(b, base); err != nil { + return parsedAction{}, err + } + + actionType := base.GetMetadata().GetType() + actionName := base.GetMetadata().GetName() + + mt, err := protoregistry.GlobalTypes.FindMessageByURL(actionType) + if err != nil { + return parsedAction{}, fmt.Errorf("lookup action type %q: %w", actionType, err) + } + + m := mt.New().Interface() + if err := proto.Unmarshal(b, m); err != nil { + return parsedAction{}, fmt.Errorf("unmarshal action %q: %w", actionType, err) + } + + // Use the action name if available, otherwise fall back to the type URL. + name := actionName + if name == "" { + name = actionType + } + + return parsedAction{name: name, message: m}, nil +} diff --git a/pkg/cmd/apply/parse_test.go b/pkg/cmd/apply/parse_test.go new file mode 100644 index 0000000..9dd4b0d --- /dev/null +++ b/pkg/cmd/apply/parse_test.go @@ -0,0 +1,121 @@ +package apply + +import ( + "testing" + + "google.golang.org/protobuf/proto" + + "github.com/Azure/AKSFlexNode/components/api" + _ "github.com/Azure/AKSFlexNode/components/linux" // register linux action types +) + +func TestIsJSONContent(t *testing.T) { + tests := []struct { + name string + input []byte + want bool + }{ + {name: "json object", input: []byte(`{"metadata":{}}`), want: true}, + {name: "json array", input: []byte(`[{"metadata":{}}]`), want: true}, + {name: "json object with leading whitespace", input: []byte(" \t\n{\"metadata\":{}}"), want: true}, + {name: "json array with leading whitespace", input: []byte("\n [{\"metadata\":{}}]"), want: true}, + {name: "binary proto (non-json first byte)", input: []byte{0x0a, 0x00}, want: false}, + {name: "empty input", input: []byte{}, want: false}, + {name: "whitespace only", input: []byte(" "), want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isJSONContent(tt.input) + if got != tt.want { + t.Errorf("isJSONContent(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestParseActionFromProto(t *testing.T) { + // Build a binary-encoded Base message that carries the ConfigureBaseOS type URL. + base := &api.Base{} + base.SetMetadata(&api.Metadata{}) + base.GetMetadata().SetType("aks.flex.components.linux.ConfigureBaseOS") + base.GetMetadata().SetName("test-action") + + b, err := proto.Marshal(base) + if err != nil { + t.Fatalf("proto.Marshal: %v", err) + } + + pa, err := parseActionFromProto(b) + if err != nil { + t.Fatalf("parseActionFromProto: %v", err) + } + + if pa.name != "test-action" { + t.Errorf("name = %q, want %q", pa.name, "test-action") + } + if pa.message == nil { + t.Error("message is nil") + } +} + +func TestParseActionFromProto_FallbackName(t *testing.T) { + // When no name is set the type URL is used as the display name. + base := &api.Base{} + base.SetMetadata(&api.Metadata{}) + base.GetMetadata().SetType("aks.flex.components.linux.ConfigureBaseOS") + + b, err := proto.Marshal(base) + if err != nil { + t.Fatalf("proto.Marshal: %v", err) + } + + pa, err := parseActionFromProto(b) + if err != nil { + t.Fatalf("parseActionFromProto: %v", err) + } + + want := "aks.flex.components.linux.ConfigureBaseOS" + if pa.name != want { + t.Errorf("name = %q, want %q", pa.name, want) + } +} + +func TestParseActionFromProto_UnknownType(t *testing.T) { + base := &api.Base{} + base.SetMetadata(&api.Metadata{}) + base.GetMetadata().SetType("does.not.Exist") + + b, err := proto.Marshal(base) + if err != nil { + t.Fatalf("proto.Marshal: %v", err) + } + + if _, err := parseActionFromProto(b); err == nil { + t.Error("expected error for unknown type, got nil") + } +} + +func TestParseActions_BinaryProtoRoundTrip(t *testing.T) { + // A real serialized proto message must be routed to the binary proto parser. + base := &api.Base{} + base.SetMetadata(&api.Metadata{}) + base.GetMetadata().SetType("aks.flex.components.linux.ConfigureBaseOS") + base.GetMetadata().SetName("round-trip") + + b, err := proto.Marshal(base) + if err != nil { + t.Fatalf("proto.Marshal: %v", err) + } + + parsed, err := parseActions(b) + if err != nil { + t.Fatalf("parseActions: %v", err) + } + if len(parsed) != 1 { + t.Fatalf("len(parsed) = %d, want 1", len(parsed)) + } + if parsed[0].name != "round-trip" { + t.Errorf("name = %q, want %q", parsed[0].name, "round-trip") + } +}