Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 6 additions & 59 deletions pkg/cmd/apply/apply.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package apply

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
Expand All @@ -13,11 +11,8 @@ import (
"github.com/schollz/progressbar/v3"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoregistry"

"github.com/Azure/AKSFlexNode/components/api"
"github.com/Azure/AKSFlexNode/components/services/actions"
"github.com/Azure/AKSFlexNode/components/services/inmem"
)
Expand All @@ -44,7 +39,11 @@ var Command = &cobra.Command{
return err
}

return apply(cmd.Context(), input, flagNoPrettyUI)
parsed, err := parseActions(input)
if err != nil {
return err
}
return apply(cmd.Context(), parsed, flagNoPrettyUI)
},
SilenceUsage: true,
}
Expand Down Expand Up @@ -76,37 +75,13 @@ type stepResult struct {
err error
}

func apply(ctx context.Context, input []byte, noPrettyUI bool) error {
func apply(ctx context.Context, parsed []parsedAction, noPrettyUI bool) error {
conn, err := inmem.NewConnection()
if err != nil {
return err
}
defer conn.Close() //nolint:errcheck // close connection

tok, err := json.NewDecoder(bytes.NewBuffer(input)).Token()
if err != nil {
return err
}

var bs []json.RawMessage
if tok == json.Delim('[') {
if err := json.Unmarshal(input, &bs); err != nil {
return err
}
} else {
bs = append(bs, input)
}

// Pre-parse all actions so we know the total count and names up front.
parsed := make([]parsedAction, 0, len(bs))
for _, b := range bs {
pa, err := parseAction(b)
if err != nil {
return err
}
parsed = append(parsed, pa)
}

if noPrettyUI {
return applyPlain(ctx, conn, parsed)
}
Expand Down Expand Up @@ -243,31 +218,3 @@ func formatDuration(d time.Duration) string {
return fmt.Sprintf("%.1fs", d.Seconds())
}
}

func parseAction(b []byte) (parsedAction, error) {
base := &api.Base{}
if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(b, base); err != nil {
return parsedAction{}, err
}

actionType := base.GetMetadata().GetType()
actionName := base.GetMetadata().GetName()

mt, err := protoregistry.GlobalTypes.FindMessageByURL(actionType)
if err != nil {
return parsedAction{}, fmt.Errorf("lookup action type %q: %w", actionType, err)
}

m := mt.New().Interface()
if err := protojson.Unmarshal(b, m); err != nil {
return parsedAction{}, fmt.Errorf("unmarshal action %q: %w", actionType, err)
}

// Use the action name if available, otherwise fall back to the type URL.
name := actionName
if name == "" {
name = actionType
}

return parsedAction{name: name, message: m}, nil
}
125 changes: 125 additions & 0 deletions pkg/cmd/apply/parse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package apply

import (
"bytes"
"encoding/json"
"fmt"

"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoregistry"

"github.com/Azure/AKSFlexNode/components/api"
)

// isJSONContent reports whether input appears to be a JSON object or array.
// It inspects the first non-whitespace byte: JSON objects begin with '{' and
// JSON arrays begin with '['; binary protobuf never starts with either byte.
func isJSONContent(input []byte) bool {
for _, b := range input {
if b == ' ' || b == '\t' || b == '\r' || b == '\n' {
continue
}
return b == '{' || b == '['
}
return false
}

// parseActions detects the input format and returns the pre-parsed actions.
func parseActions(input []byte) ([]parsedAction, error) {
if isJSONContent(input) {
return parseActionFromJSON(input)
}

pa, err := parseActionFromProto(input)
if err != nil {
return nil, err
}
return []parsedAction{pa}, nil
}

// parseActionFromJSON deserializes one or more JSON-encoded actions. The input
// may be a single JSON object or a JSON array of objects.
func parseActionFromJSON(input []byte) ([]parsedAction, error) {
tok, err := json.NewDecoder(bytes.NewBuffer(input)).Token()
if err != nil {
return nil, err
}

var bs []json.RawMessage
if tok == json.Delim('[') {
if err := json.Unmarshal(input, &bs); err != nil {
return nil, err
}
} else {
bs = append(bs, input)
}

// Pre-parse all actions so we know the total count and names up front.
parsed := make([]parsedAction, 0, len(bs))
for _, b := range bs {
pa, err := parseAction(b)
if err != nil {
return nil, err
}
parsed = append(parsed, pa)
}
return parsed, nil
}

func parseAction(b []byte) (parsedAction, error) {
base := &api.Base{}
if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(b, base); err != nil {
return parsedAction{}, err
}

actionType := base.GetMetadata().GetType()
actionName := base.GetMetadata().GetName()

mt, err := protoregistry.GlobalTypes.FindMessageByURL(actionType)
if err != nil {
return parsedAction{}, fmt.Errorf("lookup action type %q: %w", actionType, err)
}

m := mt.New().Interface()
if err := protojson.Unmarshal(b, m); err != nil {
return parsedAction{}, fmt.Errorf("unmarshal action %q: %w", actionType, err)
}

// Use the action name if available, otherwise fall back to the type URL.
name := actionName
if name == "" {
name = actionType
}

return parsedAction{name: name, message: m}, nil
}

// parseActionFromProto deserializes a single binary protobuf-encoded action.
func parseActionFromProto(b []byte) (parsedAction, error) {
base := &api.Base{}
if err := proto.Unmarshal(b, base); err != nil {
return parsedAction{}, err
}

actionType := base.GetMetadata().GetType()
actionName := base.GetMetadata().GetName()

mt, err := protoregistry.GlobalTypes.FindMessageByURL(actionType)
if err != nil {
return parsedAction{}, fmt.Errorf("lookup action type %q: %w", actionType, err)
}

m := mt.New().Interface()
if err := proto.Unmarshal(b, m); err != nil {
return parsedAction{}, fmt.Errorf("unmarshal action %q: %w", actionType, err)
}

// Use the action name if available, otherwise fall back to the type URL.
name := actionName
if name == "" {
name = actionType
}

return parsedAction{name: name, message: m}, nil
}
121 changes: 121 additions & 0 deletions pkg/cmd/apply/parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package apply

import (
"testing"

"google.golang.org/protobuf/proto"

"github.com/Azure/AKSFlexNode/components/api"
_ "github.com/Azure/AKSFlexNode/components/linux" // register linux action types
)

func TestIsJSONContent(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename this file to parse_test.go

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 494e582. apply_test.go has been renamed to parse_test.go.

tests := []struct {
name string
input []byte
want bool
}{
{name: "json object", input: []byte(`{"metadata":{}}`), want: true},
{name: "json array", input: []byte(`[{"metadata":{}}]`), want: true},
{name: "json object with leading whitespace", input: []byte(" \t\n{\"metadata\":{}}"), want: true},
{name: "json array with leading whitespace", input: []byte("\n [{\"metadata\":{}}]"), want: true},
{name: "binary proto (non-json first byte)", input: []byte{0x0a, 0x00}, want: false},
{name: "empty input", input: []byte{}, want: false},
{name: "whitespace only", input: []byte(" "), want: false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isJSONContent(tt.input)
if got != tt.want {
t.Errorf("isJSONContent(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}

func TestParseActionFromProto(t *testing.T) {
// Build a binary-encoded Base message that carries the ConfigureBaseOS type URL.
base := &api.Base{}
base.SetMetadata(&api.Metadata{})
base.GetMetadata().SetType("aks.flex.components.linux.ConfigureBaseOS")
base.GetMetadata().SetName("test-action")

b, err := proto.Marshal(base)
if err != nil {
t.Fatalf("proto.Marshal: %v", err)
}

pa, err := parseActionFromProto(b)
if err != nil {
t.Fatalf("parseActionFromProto: %v", err)
}

if pa.name != "test-action" {
t.Errorf("name = %q, want %q", pa.name, "test-action")
}
if pa.message == nil {
t.Error("message is nil")
}
}

func TestParseActionFromProto_FallbackName(t *testing.T) {
// When no name is set the type URL is used as the display name.
base := &api.Base{}
base.SetMetadata(&api.Metadata{})
base.GetMetadata().SetType("aks.flex.components.linux.ConfigureBaseOS")

b, err := proto.Marshal(base)
if err != nil {
t.Fatalf("proto.Marshal: %v", err)
}

pa, err := parseActionFromProto(b)
if err != nil {
t.Fatalf("parseActionFromProto: %v", err)
}

want := "aks.flex.components.linux.ConfigureBaseOS"
if pa.name != want {
t.Errorf("name = %q, want %q", pa.name, want)
}
}

func TestParseActionFromProto_UnknownType(t *testing.T) {
base := &api.Base{}
base.SetMetadata(&api.Metadata{})
base.GetMetadata().SetType("does.not.Exist")

b, err := proto.Marshal(base)
if err != nil {
t.Fatalf("proto.Marshal: %v", err)
}

if _, err := parseActionFromProto(b); err == nil {
t.Error("expected error for unknown type, got nil")
}
}

func TestParseActions_BinaryProtoRoundTrip(t *testing.T) {
// A real serialized proto message must be routed to the binary proto parser.
base := &api.Base{}
base.SetMetadata(&api.Metadata{})
base.GetMetadata().SetType("aks.flex.components.linux.ConfigureBaseOS")
base.GetMetadata().SetName("round-trip")

b, err := proto.Marshal(base)
if err != nil {
t.Fatalf("proto.Marshal: %v", err)
}

parsed, err := parseActions(b)
if err != nil {
t.Fatalf("parseActions: %v", err)
}
if len(parsed) != 1 {
t.Fatalf("len(parsed) = %d, want 1", len(parsed))
}
if parsed[0].name != "round-trip" {
t.Errorf("name = %q, want %q", parsed[0].name, "round-trip")
}
}
Loading