diff --git a/environments/flags.go b/environments/flags.go index dc5bfcec..2293f11d 100644 --- a/environments/flags.go +++ b/environments/flags.go @@ -9,10 +9,12 @@ import ( const ( // CarverBlockSizeValue to configure size in bytes for carver blocks CarverBlockSizeValue string = "5120000" + // FlagGenericValue to use as generator for generic flags + FlagGenericValue string = `--{{ .FlagName }}={{ .FlagValue }}` // FlagTLSServerCerts for the --tls_server_certs flag - FlagTLSServerCerts string = `--tls_server_certs={{ .CertFile }}` + FlagNameTLSServerCerts string = `tls_server_certs` // FlagCarverBlockSize for the --carver_block_size flag - FlagCarverBlockSize string = `--carver_block_size={{ .BlockSize }}` + FlagNameCarverBlockSize string = `carver_block_size` // FlagsTemplate to generate flags for enrolling nodes FlagsTemplate string = ` --host_identifier=uuid @@ -63,27 +65,32 @@ func GenServerCertsFlag(certificatePath string) string { if certificatePath == "" { return "" } - data := struct { - CertFile string - }{ - CertFile: certificatePath, - } - return GenGenericFlag("servercerts", FlagTLSServerCerts, data) + return GenSingleFlag("servercerts", FlagNameTLSServerCerts, certificatePath) } // GenCarveBlockSizeFlag to generate the --carver_block_size flag func GenCarveBlockSizeFlag(blockSize string) string { + if blockSize == "" { + return "" + } + return GenSingleFlag("blocksize", FlagNameCarverBlockSize, blockSize) +} + +// GenSingleFlag to generate a generic flag to be used by osquery +func GenSingleFlag(tmplName, flagName, flagValue string) string { data := struct { - BlockSize string + FlagName string + FlagValue string }{ - BlockSize: blockSize, + FlagName: flagName, + FlagValue: flagValue, } - return GenGenericFlag("blocksize", FlagCarverBlockSize, data) + return ParseFlagTemplate(tmplName, FlagGenericValue, data) } -// GenGenericFlag to generate a generic flag to be used by osquery -func GenGenericFlag(flagName, flagConst string, data interface{}) string { - t, err := template.New(flagName).Parse(flagConst) +// ParseFlagTemplate to parse a flag template +func ParseFlagTemplate(tmplName, flagTemplate string, data interface{}) string { + t, err := template.New(tmplName).Parse(flagTemplate) if err != nil { return "" } @@ -114,7 +121,7 @@ func (environment *Environment) GenerateFlags(env TLSEnvironment, secretPath, ce FlagServerCerts: flagServerCerts, FlagCarverBlock: GenCarveBlockSizeFlag(CarverBlockSizeValue), } - return GenGenericFlag("flags", FlagsTemplate, data), nil + return ParseFlagTemplate("flags", FlagsTemplate, data), nil } // GenerateFlagsEnv to generate flags by environment name diff --git a/environments/flags_test.go b/environments/flags_test.go new file mode 100644 index 00000000..c0c2d72f --- /dev/null +++ b/environments/flags_test.go @@ -0,0 +1,73 @@ +package environments + +import ( + "testing" +) + +var testEnv = TLSEnvironment{} + +func TestGenServerCertsFlag(t *testing.T) { + t.Run("empty", func(t *testing.T) { + flag := GenServerCertsFlag("") + if flag != "" { + t.Errorf("Expected empty flag, got %s", flag) + } + }) + t.Run("not empty", func(t *testing.T) { + flag := GenServerCertsFlag("certificate") + if flag != "--tls_server_certs=certificate" { + t.Errorf("Expected flag --tls_server_certs=certificate, got %s", flag) + } + }) +} + +func TestGenCarveBlockSizeFlag(t *testing.T) { + t.Run("empty", func(t *testing.T) { + flag := GenCarveBlockSizeFlag("") + if flag != "" { + t.Errorf("Expected empty flag, got %s", flag) + } + }) + t.Run("not empty", func(t *testing.T) { + flag := GenCarveBlockSizeFlag("blockSize") + if flag != "--carver_block_size=blockSize" { + t.Errorf("Expected flag --carver_block_size=blockSize, got %s", flag) + } + }) +} + +func TestGenSingleFlag(t *testing.T) { + t.Run("empty", func(t *testing.T) { + flag := GenSingleFlag("tmplName", "flagName", "") + if flag != "--flagName=" { + t.Errorf("Expected --flagName=, got %s", flag) + } + }) + t.Run("not empty", func(t *testing.T) { + flag := GenSingleFlag("tmplName", "flagName", "flagValue") + if flag != "--flagName=flagValue" { + t.Errorf("Expected flag --flagName=flagValue, got %s", flag) + } + }) +} + +func TestParseFlagTemplate(t *testing.T) { + t.Run("empty data", func(t *testing.T) { + flag := ParseFlagTemplate("tmplName", "flagTemplate", nil) + if flag != "flagTemplate" { + t.Errorf("Expected empty flag, got %s", flag) + } + }) + t.Run("not empty data", func(t *testing.T) { + flag := ParseFlagTemplate("tmplName", "--{{ .Name }}={{ .Value }}", struct { + Name string + Value string + }{ + Name: "flagName", + Value: "flagValue", + }) + if flag != "--flagName=flagValue" { + t.Errorf("Expected flag --flagName=flagValue, got %s", flag) + } + }) +}