diff --git a/.gitignore b/.gitignore index a2d044d..b089932 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ # Test binary, built with `go test -c` *.test +# Test Files +test_* + # Output of the go coverage tool, specifically when used with LiteIDE *.out diff --git a/README.md b/README.md index 3aede34..b26ec5c 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Looking to contribute? Check out: šŸ”„ **Multiple LLM Support** - Choose between Google Gemini, Grok, Claude, ChatGPT, or Ollama (local) šŸ“ **Context-Aware** - Analyzes staged and unstaged changes šŸ“‹ **Auto-Copy to Clipboard** - Generated messages are automatically copied for instant use +šŸŽ›ļø **Interactive Review Flow** - Accept, regenerate with new styles, or open the message in your editor before committing šŸ“Š **File Statistics Display** - Visual preview of changed files and line counts šŸš€ **Easy to Use** - Simple CLI interface with beautiful terminal UI ⚔ **Fast** - Quick generation of commit messages @@ -62,7 +63,6 @@ You can use **Google Gemini**, **Grok**, **Claude**, **ChatGPT**, or **Ollama** echo 'export PATH=$PATH:/path/to/commit-msg' >> ~/.bashrc # or ~/.zshrc ``` - ### Option 2: Build from Source Requirements: Go 1.23.4 or higher @@ -105,6 +105,7 @@ go run cmd/commit-msg/main.go . ```bash commit llm setup ``` + Screenshot 2025-10-05 172731 Screenshot 2025-10-05 172748 @@ -113,12 +114,12 @@ go run cmd/commit-msg/main.go . ```bash commit llm update ``` + Screenshot 2025-10-05 172814 Screenshot 2025-10-05 172823 ### Example Workflow - ```bash # Make changes to your code echo "console.log('Hello World')" > app.js @@ -137,6 +138,17 @@ commit . # You can now paste it with Ctrl+V (or Cmd+V on macOS) ``` +### Interactive Commit Workflow + +Once the commit message is generated, the CLI now offers a quick review loop: + +- **Accept & copy** – use the message as-is (it still lands on your clipboard automatically) +- **Regenerate** – pick from presets like detailed summaries, casual tone, bug-fix emphasis, or provide custom instructions for the LLM +- **Edit in your editor** – open the message in `$GIT_EDITOR`, `$VISUAL`, `$EDITOR`, or a sensible fallback (`notepad` on Windows, `nano` elsewhere) +- **Exit** – leave without copying anything if the message isn't ready yet + +This makes it easy to tweak the tone, iterate on suggestions, or fine-tune the final wording before you commit. + ### Use Cases - šŸ“ Generate commit messages for staged changes @@ -162,20 +174,24 @@ commit . commit llm update ``` -**Set LLM as default** +### Set LLM as default + ```bash Select: Set Default ``` -**Change API Key** +### Change API Key + ```bash Select: Change API Key ``` -**Delete LLM** +### Delete LLM + ```bash Select: Delete ``` + --- ## Getting API Keys @@ -185,40 +201,35 @@ Select: Delete 1. Visit [Google AI Studio](https://makersuite.google.com/app/apikey) 2. Create a new API key - **Grok (X.AI):** 1. Visit [X.AI Console](https://console.x.ai/) 2. Generate an API key - **Groq:** 1. Sign up at [Groq Cloud](https://console.groq.com/) 2. Create an API key - **Claude (Anthropic):** 1. Visit the [Anthropic Console](https://console.anthropic.com/) 2. Create a new API key - **OpenAI (ChatGPT):** 1. Visit [OpenAI Platform](https://platform.openai.com/api-keys) 2. Create a new API key - **Ollama (Local LLM):** 1. Install Ollama: Visit [Ollama.ai](https://ollama.ai/) and follow installation instructions 2. Start Ollama: `ollama serve` -3. Pull a model: `ollama pull llama3` +3. Pull a model: `ollama pull llama3.1` 4. Set environment variables: ```bash export COMMIT_LLM=ollama - export OLLAMA_MODEL=llama3 # llama3 by default + export OLLAMA_MODEL=llama3.1 # llama3.1 by default ``` --- diff --git a/cmd/cli/createMsg.go b/cmd/cli/createMsg.go index 7c062b1..5c08578 100644 --- a/cmd/cli/createMsg.go +++ b/cmd/cli/createMsg.go @@ -1,7 +1,12 @@ package cmd import ( + "errors" + "fmt" "os" + "os/exec" + "runtime" + "strings" "github.com/atotto/clipboard" "github.com/dfanso/commit-msg/cmd/cli/store" @@ -15,11 +20,13 @@ import ( "github.com/dfanso/commit-msg/internal/ollama" "github.com/dfanso/commit-msg/internal/stats" "github.com/dfanso/commit-msg/pkg/types" + "github.com/google/shlex" "github.com/pterm/pterm" ) +// CreateCommitMsg launches the interactive flow for reviewing, regenerating, +// editing, and accepting AI-generated commit messages in the current repo. func CreateCommitMsg() { - // Validate COMMIT_LLM and required API keys useLLM, err := store.DefaultLLMKey() if err != nil { @@ -43,32 +50,24 @@ func CreateCommitMsg() { os.Exit(1) } - // Create a minimal config for the API config := &types.Config{ GrokAPI: "https://api.x.ai/v1/chat/completions", } - // Create a repo config for the current directory - repoConfig := types.RepoConfig{ - Path: currentDir, - } + repoConfig := types.RepoConfig{Path: currentDir} - // Get file statistics before fetching changes fileStats, err := stats.GetFileStatistics(&repoConfig) if err != nil { pterm.Error.Printf("Failed to get file statistics: %v\n", err) os.Exit(1) } - // Display header pterm.DefaultHeader.WithFullWidth(). WithBackgroundStyle(pterm.NewStyle(pterm.BgDarkGray)). WithTextStyle(pterm.NewStyle(pterm.FgLightWhite)). Println("Commit Message Generator") pterm.Println() - - // Display file statistics with icons display.ShowFileStatistics(fileStats) if fileStats.TotalFiles == 0 { @@ -80,7 +79,6 @@ func CreateCommitMsg() { return } - // Get the changes changes, err := git.GetChanges(&repoConfig) if err != nil { pterm.Error.Printf("Failed to get Git changes: %v\n", err) @@ -97,8 +95,6 @@ func CreateCommitMsg() { } pterm.Println() - - // Show generating spinner spinnerGenerating, err := pterm.DefaultSpinner. WithSequence("ā ‹", "ā ™", "ā ¹", "ā ø", "ā ¼", "ā “", "ā ¦", "ā §", "ā ‡", "ā "). Start("Generating commit message with " + commitLLM.String() + "...") @@ -107,65 +103,325 @@ func CreateCommitMsg() { os.Exit(1) } - var commitMsg string + attempt := 1 + commitMsg, err := generateMessage(commitLLM, config, changes, apiKey, withAttempt(nil, attempt)) + if err != nil { + spinnerGenerating.Fail("Failed to generate commit message") + displayProviderError(commitLLM, err) + os.Exit(1) + } - switch commitLLM { + spinnerGenerating.Success("Commit message generated successfully!") - case types.ProviderGemini: - commitMsg, err = gemini.GenerateCommitMessage(config, changes, apiKey) + currentMessage := strings.TrimSpace(commitMsg) + currentStyleLabel := stylePresets[0].Label + var currentStyleOpts *types.GenerationOptions + accepted := false + finalMessage := "" + +interactionLoop: + for { + pterm.Println() + display.ShowCommitMessage(currentMessage) + + action, err := promptActionSelection() + if err != nil { + pterm.Error.Printf("Failed to read selection: %v\n", err) + return + } - case types.ProviderOpenAI: - commitMsg, err = chatgpt.GenerateCommitMessage(config, changes, apiKey) + switch action { + case actionAcceptOption: + finalMessage = strings.TrimSpace(currentMessage) + if finalMessage == "" { + pterm.Warning.Println("Commit message is empty; please edit or regenerate before accepting.") + continue + } + if err := clipboard.WriteAll(finalMessage); err != nil { + pterm.Warning.Printf("Could not copy to clipboard: %v\n", err) + } else { + pterm.Success.Println("Commit message copied to clipboard!") + } + accepted = true + break interactionLoop + case actionRegenerateOption: + opts, styleLabel, err := promptStyleSelection(currentStyleLabel, currentStyleOpts) + if errors.Is(err, errSelectionCancelled) { + continue + } + if err != nil { + pterm.Error.Printf("Failed to select style: %v\n", err) + continue + } + if styleLabel != "" { + currentStyleLabel = styleLabel + } + currentStyleOpts = opts + nextAttempt := attempt + 1 + generationOpts := withAttempt(currentStyleOpts, nextAttempt) + spinner, err := pterm.DefaultSpinner. + WithSequence("ā ‹", "ā ™", "ā ¹", "ā ø", "ā ¼", "ā “", "ā ¦", "ā §", "ā ‡", "ā "). + Start(fmt.Sprintf("Regenerating commit message (%s)...", currentStyleLabel)) + if err != nil { + pterm.Error.Printf("Failed to start spinner: %v\n", err) + continue + } + updatedMessage, genErr := generateMessage(commitLLM, config, changes, apiKey, generationOpts) + if genErr != nil { + spinner.Fail("Regeneration failed") + displayProviderError(commitLLM, genErr) + continue + } + spinner.Success("Commit message regenerated!") + attempt = nextAttempt + currentMessage = strings.TrimSpace(updatedMessage) + case actionEditOption: + edited, editErr := editCommitMessage(currentMessage) + if editErr != nil { + pterm.Error.Printf("Failed to edit commit message: %v\n", editErr) + continue + } + if strings.TrimSpace(edited) == "" { + pterm.Warning.Println("Edited commit message is empty; keeping previous message.") + continue + } + currentMessage = strings.TrimSpace(edited) + case actionExitOption: + pterm.Info.Println("Exiting without copying commit message.") + return + default: + pterm.Warning.Printf("Unknown selection: %s\n", action) + } + } + + if !accepted { + return + } + + pterm.Println() + display.ShowChangesPreview(fileStats) +} + +type styleOption struct { + Label string + Instruction string +} + +const ( + actionAcceptOption = "Accept and copy commit message" + actionRegenerateOption = "Regenerate with different tone/style" + actionEditOption = "Edit message in editor" + actionExitOption = "Discard and exit" + customStyleOption = "Custom instructions (enter your own)" + styleBackOption = "Back to actions" +) + +var ( + actionOptions = []string{actionAcceptOption, actionRegenerateOption, actionEditOption, actionExitOption} + stylePresets = []styleOption{ + {Label: "Concise conventional (default)", Instruction: ""}, + {Label: "Detailed summary (adds bullet list)", Instruction: "Produce a conventional commit subject line followed by a blank line and bullet points summarizing the key changes."}, + {Label: "Casual tone", Instruction: "Write the commit message in a friendly, conversational tone while still clearly explaining the changes."}, + {Label: "Bug fix emphasis", Instruction: "Highlight the bug being fixed, reference the root cause when possible, and describe the remedy in the body."}, + } + errSelectionCancelled = errors.New("selection cancelled") +) +func generateMessage(provider types.LLMProvider, config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { + switch provider { + case types.ProviderGemini: + return gemini.GenerateCommitMessage(config, changes, apiKey, opts) + case types.ProviderOpenAI: + return chatgpt.GenerateCommitMessage(config, changes, apiKey, opts) case types.ProviderClaude: - commitMsg, err = claude.GenerateCommitMessage(config, changes, apiKey) + return claude.GenerateCommitMessage(config, changes, apiKey, opts) case types.ProviderGroq: - commitMsg, err = groq.GenerateCommitMessage(config, changes, apiKey) + return groq.GenerateCommitMessage(config, changes, apiKey, opts) case types.ProviderOllama: - model := "llama3:latest" - - commitMsg, err = ollama.GenerateCommitMessage(config, changes, apiKey, model) + url := apiKey + if strings.TrimSpace(url) == "" { + url = os.Getenv("OLLAMA_URL") + if url == "" { + url = "http://localhost:11434/api/generate" + } + } + model := os.Getenv("OLLAMA_MODEL") + if model == "" { + model = "llama3.1" + } + return ollama.GenerateCommitMessage(config, changes, url, model, opts) default: - commitMsg, err = grok.GenerateCommitMessage(config, changes, apiKey) + return grok.GenerateCommitMessage(config, changes, apiKey, opts) } +} +func promptActionSelection() (string, error) { + return pterm.DefaultInteractiveSelect. + WithOptions(actionOptions). + WithDefaultOption(actionAcceptOption). + Show() +} + +func promptStyleSelection(currentLabel string, currentOpts *types.GenerationOptions) (*types.GenerationOptions, string, error) { + options := make([]string, 0, len(stylePresets)+3) + foundCurrent := false + for _, preset := range stylePresets { + options = append(options, preset.Label) + if preset.Label == currentLabel { + foundCurrent = true + } + } + if currentOpts != nil && currentLabel != "" && !foundCurrent { + options = append(options, currentLabel) + } + options = append(options, customStyleOption, styleBackOption) + + selector := pterm.DefaultInteractiveSelect.WithOptions(options) + if currentLabel != "" { + selector = selector.WithDefaultOption(currentLabel) + } + + choice, err := selector.Show() if err != nil { - spinnerGenerating.Fail("Failed to generate commit message") - switch commitLLM { - case types.ProviderGemini: - pterm.Error.Printf("Gemini API error. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n") - case types.ProviderOpenAI: - pterm.Error.Printf("OpenAI API error. Check your OPENAI_API_KEY environment variable or run: commit llm setup\n") - case types.ProviderClaude: - pterm.Error.Printf("Claude API error. Check your CLAUDE_API_KEY environment variable or run: commit llm setup\n") - case types.ProviderGroq: - pterm.Error.Printf("Groq API error. Check your GROQ_API_KEY environment variable or run: commit llm setup\n") - case types.ProviderGrok: - pterm.Error.Printf("Grok API error. Check your GROK_API_KEY environment variable or run: commit llm setup\n") - default: - pterm.Error.Printf("LLM API error: %v\n", err) + return currentOpts, currentLabel, err + } + + switch choice { + case styleBackOption: + return currentOpts, currentLabel, errSelectionCancelled + case customStyleOption: + text, err := pterm.DefaultInteractiveTextInput. + WithDefaultText("Describe the tone or style you're looking for"). + Show() + if err != nil { + return currentOpts, currentLabel, err + } + text = strings.TrimSpace(text) + if text == "" { + return currentOpts, currentLabel, errSelectionCancelled + } + return &types.GenerationOptions{StyleInstruction: text}, formatCustomStyleLabel(text), nil + default: + for _, preset := range stylePresets { + if choice == preset.Label { + if strings.TrimSpace(preset.Instruction) == "" { + return nil, preset.Label, nil + } + return &types.GenerationOptions{StyleInstruction: preset.Instruction}, preset.Label, nil + } + } + if currentOpts != nil && choice == currentLabel { + clone := *currentOpts + return &clone, currentLabel, nil } - os.Exit(1) } - spinnerGenerating.Success("Commit message generated successfully!") + if currentOpts != nil && currentLabel != "" { + clone := *currentOpts + return &clone, currentLabel, nil + } + return nil, currentLabel, nil +} - pterm.Println() +func editCommitMessage(initial string) (string, error) { + command, args, err := resolveEditorCommand() + if err != nil { + return "", err + } - // Display the commit message in a styled panel - display.ShowCommitMessage(commitMsg) + tmpFile, err := os.CreateTemp("", "commit-msg-*.txt") + if err != nil { + return "", err + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(strings.TrimSpace(initial) + "\n"); err != nil { + tmpFile.Close() + return "", err + } + + if err := tmpFile.Close(); err != nil { + return "", err + } - // Copy to clipboard - err = clipboard.WriteAll(commitMsg) + cmdArgs := append(args, tmpFile.Name()) + cmd := exec.Command(command, cmdArgs...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("editor exited with error: %w", err) + } + + content, err := os.ReadFile(tmpFile.Name()) if err != nil { - pterm.Warning.Printf("Could not copy to clipboard: %v\n", err) - } else { - pterm.Success.Println("Commit message copied to clipboard!") + return "", err } - pterm.Println() + return strings.TrimSpace(string(content)), nil +} - // Display changes preview - display.ShowChangesPreview(fileStats) +func resolveEditorCommand() (string, []string, error) { + candidates := []string{ + os.Getenv("GIT_EDITOR"), + os.Getenv("VISUAL"), + os.Getenv("EDITOR"), + } + + for _, candidate := range candidates { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + parts, err := shlex.Split(candidate) + if err != nil { + return "", nil, fmt.Errorf("failed to parse editor command %q: %w", candidate, err) + } + if len(parts) == 0 { + continue + } + return parts[0], parts[1:], nil + } + + if runtime.GOOS == "windows" { + return "notepad", nil, nil + } + return "nano", nil, nil +} + +func formatCustomStyleLabel(instruction string) string { + trimmed := strings.TrimSpace(instruction) + runes := []rune(trimmed) + if len(runes) > 40 { + return fmt.Sprintf("Custom: %s…", string(runes[:37])) + } + return fmt.Sprintf("Custom: %s", trimmed) +} + +func withAttempt(styleOpts *types.GenerationOptions, attempt int) *types.GenerationOptions { + if styleOpts == nil { + return &types.GenerationOptions{Attempt: attempt} + } + clone := *styleOpts + clone.Attempt = attempt + return &clone +} + +func displayProviderError(provider types.LLMProvider, err error) { + switch provider { + case types.ProviderGemini: + pterm.Error.Printf("Gemini API error: %v. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n", err) + case types.ProviderOpenAI: + pterm.Error.Printf("OpenAI API error: %v. Check your OPENAI_API_KEY environment variable or run: commit llm setup\n", err) + case types.ProviderClaude: + pterm.Error.Printf("Claude API error: %v. Check your CLAUDE_API_KEY environment variable or run: commit llm setup\n", err) + case types.ProviderGroq: + pterm.Error.Printf("Groq API error: %v. Check your GROQ_API_KEY environment variable or run: commit llm setup\n", err) + case types.ProviderGrok: + pterm.Error.Printf("Grok API error: %v. Check your GROK_API_KEY environment variable or run: commit llm setup\n", err) + default: + pterm.Error.Printf("LLM API error: %v\n", err) + } } diff --git a/cmd/cli/llmSetup.go b/cmd/cli/llmSetup.go index f551c46..4ef57e3 100644 --- a/cmd/cli/llmSetup.go +++ b/cmd/cli/llmSetup.go @@ -9,6 +9,8 @@ import ( "github.com/manifoldco/promptui" ) +// SetupLLM walks the user through selecting an LLM provider and storing the +// corresponding API key or endpoint configuration. func SetupLLM() error { providers := types.GetSupportedProviderStrings() @@ -67,6 +69,8 @@ func SetupLLM() error { return nil } +// UpdateLLM lets the user switch defaults, rotate API keys, or delete stored +// LLM provider configurations. func UpdateLLM() error { SavedModels, err := store.ListSavedModels() diff --git a/cmd/cli/store/store.go b/cmd/cli/store/store.go index b18c1c2..efa2b98 100644 --- a/cmd/cli/store/store.go +++ b/cmd/cli/store/store.go @@ -1,3 +1,4 @@ +// Package store persists user-selected LLM providers and credentials. package store import ( @@ -11,16 +12,19 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) +// LLMProvider represents a single stored LLM provider and its credential. type LLMProvider struct { LLM types.LLMProvider `json:"model"` APIKey string `json:"api_key"` } +// Config describes the on-disk structure for all saved LLM providers. type Config struct { Default types.LLMProvider `json:"default"` LLMProviders []LLMProvider `json:"models"` } +// Save persists or updates an LLM provider entry, marking it as the default. func Save(LLMConfig LLMProvider) error { cfg := Config{ @@ -139,6 +143,7 @@ func getConfigPath() (string, error) { } +// DefaultLLMKey returns the currently selected default LLM provider, if any. func DefaultLLMKey() (*LLMProvider, error) { var cfg Config @@ -179,6 +184,7 @@ func DefaultLLMKey() (*LLMProvider, error) { return nil, errors.New("not found default model in config") } +// ListSavedModels loads all persisted LLM provider configurations. func ListSavedModels() (*Config, error) { var cfg Config @@ -211,6 +217,7 @@ func ListSavedModels() (*Config, error) { } +// ChangeDefault updates the default LLM provider selection in the config. func ChangeDefault(Model types.LLMProvider) error { var cfg Config @@ -237,6 +244,17 @@ func ChangeDefault(Model types.LLMProvider) error { } } + found := false + for _, p := range cfg.LLMProviders { + if p.LLM == Model { + found = true + break + } + } + if !found { + return fmt.Errorf("cannot set default to %s: no saved entry", Model.String()) + } + cfg.Default = Model data, err = json.MarshalIndent(cfg, "", " ") @@ -247,6 +265,7 @@ func ChangeDefault(Model types.LLMProvider) error { return os.WriteFile(configPath, data, 0600) } +// DeleteModel removes the specified provider from the saved configuration. func DeleteModel(Model types.LLMProvider) error { var cfg Config @@ -300,6 +319,7 @@ func DeleteModel(Model types.LLMProvider) error { } } +// UpdateAPIKey rotates the credential for an existing provider entry. func UpdateAPIKey(Model types.LLMProvider, APIKey string) error { var cfg Config @@ -326,12 +346,18 @@ func UpdateAPIKey(Model types.LLMProvider, APIKey string) error { } } + updated := false for i, p := range cfg.LLMProviders { if p.LLM == Model { cfg.LLMProviders[i].APIKey = APIKey + updated = true } } + if !updated { + return fmt.Errorf("no saved entry for %s to update", Model.String()) + } + data, err = json.MarshalIndent(cfg, "", " ") if err != nil { return err diff --git a/go.mod b/go.mod index 99570d6..5860905 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.24.7 require ( github.com/atotto/clipboard v0.1.4 + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/generative-ai-go v0.19.0 github.com/manifoldco/promptui v0.9.0 github.com/openai/openai-go/v3 v3.0.1 diff --git a/go.sum b/go.sum index fe4be9b..7507149 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= diff --git a/internal/chatgpt/chatgpt.go b/internal/chatgpt/chatgpt.go index ba1cf2e..714c046 100644 --- a/internal/chatgpt/chatgpt.go +++ b/internal/chatgpt/chatgpt.go @@ -10,11 +10,13 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) -func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) { +// GenerateCommitMessage calls OpenAI's chat completions API to turn the provided +// repository changes into a polished git commit message. +func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { client := openai.NewClient(option.WithAPIKey(apiKey)) - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) resp, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ Messages: []openai.ChatCompletionMessageParamUnion{ diff --git a/internal/claude/claude.go b/internal/claude/claude.go index 1ed0dc6..5a0a02b 100644 --- a/internal/claude/claude.go +++ b/internal/claude/claude.go @@ -10,17 +10,14 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) +// ClaudeRequest describes the payload sent to Anthropic's Claude messages API. type ClaudeRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens int `json:"max_tokens"` -} - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` + Model string `json:"model"` + Messages []types.Message `json:"messages"` + MaxTokens int `json:"max_tokens"` } +// ClaudeResponse captures the subset of fields used from Anthropic responses. type ClaudeResponse struct { ID string `json:"id"` Type string `json:"type"` @@ -30,14 +27,15 @@ type ClaudeResponse struct { } `json:"content"` } -func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) { +// GenerateCommitMessage produces a commit summary using Anthropic's Claude API. +func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) reqBody := ClaudeRequest{ Model: "claude-3-5-sonnet-20241022", MaxTokens: 200, - Messages: []Message{ + Messages: []types.Message{ { Role: "user", Content: prompt, diff --git a/internal/gemini/gemini.go b/internal/gemini/gemini.go index b231868..0606d49 100644 --- a/internal/gemini/gemini.go +++ b/internal/gemini/gemini.go @@ -10,9 +10,11 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) -func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) { +// GenerateCommitMessage asks Google Gemini to author a commit message for the +// supplied repository changes and optional style instructions. +func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { // Prepare request to Gemini API - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) // Create context and client ctx := context.Background() diff --git a/internal/grok/grok.go b/internal/grok/grok.go index a5acc25..65ebd4e 100644 --- a/internal/grok/grok.go +++ b/internal/grok/grok.go @@ -5,16 +5,43 @@ import ( "crypto/tls" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" + "sync" "time" "github.com/dfanso/commit-msg/pkg/types" ) -func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) { +var ( + grokClientOnce sync.Once + grokClient *http.Client +) + +func getHTTPClient() *http.Client { + grokClientOnce.Do(func() { + transport := &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + DisableCompression: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: false, + }, + } + grokClient = &http.Client{ + Timeout: 30 * time.Second, + Transport: transport, + } + }) + return grokClient +} + +// GenerateCommitMessage calls X.AI's Grok API to create a commit message from +// the provided Git diff and generation options. +func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { // Prepare request to X.AI (Grok) API - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) request := types.GrokRequest{ @@ -44,22 +71,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) - // Configure HTTP client with improved TLS settings - transport := &http.Transport{ - TLSHandshakeTimeout: 10 * time.Second, - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Second, - DisableCompression: true, - // Add TLS config to handle server name mismatch - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: false, // Keep this false for security - }, - } - - client := &http.Client{ - Timeout: 30 * time.Second, - Transport: transport, - } + client := getHTTPClient() resp, err := client.Do(req) if err != nil { return "", err @@ -68,7 +80,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string) // Check response status if resp.StatusCode != http.StatusOK { - bodyBytes, _ := ioutil.ReadAll(resp.Body) + bodyBytes, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) } diff --git a/internal/groq/groq.go b/internal/groq/groq.go index cee46bd..b7a1be4 100644 --- a/internal/groq/groq.go +++ b/internal/groq/groq.go @@ -43,12 +43,12 @@ var ( ) // GenerateCommitMessage calls Groq's OpenAI-compatible chat completions API. -func GenerateCommitMessage(_ *types.Config, changes string, apiKey string) (string, error) { +func GenerateCommitMessage(_ *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { if changes == "" { return "", fmt.Errorf("no changes provided for commit message generation") } - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) model := os.Getenv("GROQ_MODEL") if model == "" { diff --git a/internal/groq/groq_test.go b/internal/groq/groq_test.go index 754e4de..c1e2557 100644 --- a/internal/groq/groq_test.go +++ b/internal/groq/groq_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" "github.com/dfanso/commit-msg/pkg/types" @@ -73,7 +74,7 @@ func TestGenerateCommitMessageSuccess(t *testing.T) { t.Fatalf("failed to write response: %v", err) } }, func() { - msg, err := GenerateCommitMessage(&types.Config{}, "diff", "test-key") + msg, err := GenerateCommitMessage(&types.Config{}, "diff", "test-key", nil) if err != nil { t.Fatalf("GenerateCommitMessage returned error: %v", err) } @@ -89,7 +90,7 @@ func TestGenerateCommitMessageNonOK(t *testing.T) { withTestServer(t, func(w http.ResponseWriter, r *http.Request) { http.Error(w, `{"error":"bad things"}`, http.StatusBadGateway) }, func() { - _, err := GenerateCommitMessage(&types.Config{}, "changes", "key") + _, err := GenerateCommitMessage(&types.Config{}, "changes", "key", nil) if err == nil { t.Fatal("expected error but got nil") } @@ -100,7 +101,45 @@ func TestGenerateCommitMessageEmptyChanges(t *testing.T) { t.Setenv("GROQ_MODEL", "") t.Setenv("GROQ_API_URL", "") - if _, err := GenerateCommitMessage(&types.Config{}, "", "key"); err == nil { + if _, err := GenerateCommitMessage(&types.Config{}, "", "key", nil); err == nil { t.Fatal("expected error for empty changes") } } + +func TestGenerateCommitMessageIncludesStyleInstruction(t *testing.T) { + recorded := "" + + withTestServer(t, func(w http.ResponseWriter, r *http.Request) { + var payload capturedRequest + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("failed to decode request: %v", err) + } + recorded = payload.Messages[1].Content + + resp := chatResponse{ + Choices: []chatChoice{ + {Message: chatMessage{Role: "assistant", Content: "feat: add style support"}}, + }, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("failed to write response: %v", err) + } + }, func() { + opts := &types.GenerationOptions{StyleInstruction: "Use a casual tone.", Attempt: 2} + if _, err := GenerateCommitMessage(&types.Config{}, "diff", "key", opts); err != nil { + t.Fatalf("GenerateCommitMessage returned error: %v", err) + } + }) + + if !strings.Contains(recorded, "Additional instructions:") { + t.Fatalf("expected request payload to contain additional instructions, got: %q", recorded) + } + if !strings.Contains(recorded, "Use a casual tone.") { + t.Fatalf("expected request payload to contain custom instruction, got: %q", recorded) + } + if !strings.Contains(recorded, "Regeneration context:") { + t.Fatalf("expected request payload to contain regeneration context, got: %q", recorded) + } +} diff --git a/internal/ollama/ollama.go b/internal/ollama/ollama.go index 2094485..cb4f26e 100644 --- a/internal/ollama/ollama.go +++ b/internal/ollama/ollama.go @@ -10,24 +10,28 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) +// OllamaRequest captures the prompt payload sent to an Ollama HTTP endpoint. type OllamaRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` } +// OllamaResponse represents the non-streaming response from Ollama. type OllamaResponse struct { Response string `json:"response"` Done bool `json:"done"` } -func GenerateCommitMessage(_ *types.Config, changes string, url string, model string) (string, error) { +// GenerateCommitMessage uses a locally hosted Ollama model to draft a commit +// message from repository changes and optional style guidance. +func GenerateCommitMessage(_ *types.Config, changes string, url string, model string, opts *types.GenerationOptions) (string, error) { // Use llama3:latest as the default model if model == "" { model = "llama3:latest" } // Preparing the prompt - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) // Generating the request body - add stream: false for non-streaming response reqBody := map[string]interface{}{ diff --git a/pkg/types/options.go b/pkg/types/options.go new file mode 100644 index 0000000..d7ed46d --- /dev/null +++ b/pkg/types/options.go @@ -0,0 +1,10 @@ +package types + +// GenerationOptions controls how commit messages should be produced by LLM providers. +type GenerationOptions struct { + // StyleInstruction contains optional tone/style guidance appended to the base prompt. + StyleInstruction string + // Attempt records the 1-indexed attempt number for this generation request. + // Attempt > 1 signals that the LLM should provide an alternative output. + Attempt int +} diff --git a/pkg/types/prompt.go b/pkg/types/prompt.go index 5b9a2f5..cd55192 100644 --- a/pkg/types/prompt.go +++ b/pkg/types/prompt.go @@ -1,5 +1,12 @@ package types +import ( + "fmt" + "strings" +) + +// CommitPrompt is the base instruction template sent to LLM providers before +// appending repository changes and optional style guidance. var CommitPrompt = `I need a concise git commit message based on the following changes from my Git repository. Please generate a commit message that: 1. Starts with a verb in the present tense (e.g., "Add", "Fix", "Update", "Feat", "Refactor", etc.) @@ -17,3 +24,28 @@ here is a sample commit msgs: - variable. Also updates dependencies and README.' Here are the changes: ` + +// BuildCommitPrompt constructs the prompt that will be sent to the LLM, applying +// any optional tone/style instructions before appending the repository changes. +func BuildCommitPrompt(changes string, opts *GenerationOptions) string { + var builder strings.Builder + builder.WriteString(CommitPrompt) + + if opts != nil { + if opts.Attempt > 1 { + builder.WriteString("\n\nRegeneration context:\n") + builder.WriteString(fmt.Sprintf("- This is attempt #%d.\n", opts.Attempt)) + builder.WriteString("- Provide a commit message that is meaningfully different from earlier attempts.\n") + } + + if strings.TrimSpace(opts.StyleInstruction) != "" { + builder.WriteString("\n\nAdditional instructions:\n") + builder.WriteString(strings.TrimSpace(opts.StyleInstruction)) + } + } + + builder.WriteString("\n\n") + builder.WriteString(changes) + + return builder.String() +} diff --git a/pkg/types/types.go b/pkg/types/types.go index 2e6f848..8bf7233 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -1,5 +1,7 @@ package types +// LLMProvider identifies the large language model backend used to author +// commit messages. type LLMProvider string const ( @@ -11,10 +13,12 @@ const ( ProviderOllama LLMProvider = "Ollama" ) +// String returns the provider identifier as a plain string. func (p LLMProvider) String() string { return string(p) } +// IsValid reports whether the provider is part of the supported set. func (p LLMProvider) IsValid() bool { switch p { case ProviderOpenAI, ProviderClaude, ProviderGemini, ProviderGrok, ProviderGroq, ProviderOllama: @@ -24,6 +28,7 @@ func (p LLMProvider) IsValid() bool { } } +// GetSupportedProviders returns all available provider enums. func GetSupportedProviders() []LLMProvider { return []LLMProvider{ ProviderOpenAI, @@ -35,6 +40,7 @@ func GetSupportedProviders() []LLMProvider { } } +// GetSupportedProviderStrings returns the human-friendly names for providers. func GetSupportedProviderStrings() []string { providers := GetSupportedProviders() strings := make([]string, len(providers)) @@ -44,24 +50,25 @@ func GetSupportedProviderStrings() []string { return strings } +// ParseLLMProvider converts a string into an LLMProvider enum when supported. func ParseLLMProvider(s string) (LLMProvider, bool) { provider := LLMProvider(s) return provider, provider.IsValid() } -// Configuration structure +// Config stores CLI-level configuration including named repositories. type Config struct { GrokAPI string `json:"grok_api"` Repos map[string]RepoConfig `json:"repos"` } -// Repository configuration +// RepoConfig tracks metadata for a configured Git repository. type RepoConfig struct { Path string `json:"path"` LastRun string `json:"last_run"` } -// Grok/X.AI API request structure +// GrokRequest represents a chat completion request sent to X.AI's API. type GrokRequest struct { Messages []Message `json:"messages"` Model string `json:"model"` @@ -69,12 +76,13 @@ type GrokRequest struct { Temperature float64 `json:"temperature"` } +// Message captures the role/content pairs exchanged with Grok. type Message struct { Role string `json:"role"` Content string `json:"content"` } -// Grok/X.AI API response structure +// GrokResponse contains the relevant fields parsed from X.AI responses. type GrokResponse struct { Message Message `json:"message,omitempty"` Choices []Choice `json:"choices,omitempty"` @@ -85,12 +93,14 @@ type GrokResponse struct { Usage UsageInfo `json:"usage,omitempty"` } +// Choice details a single response option returned by Grok. type Choice struct { Message Message `json:"message"` Index int `json:"index"` FinishReason string `json:"finish_reason"` } +// UsageInfo reports token usage statistics from Grok responses. type UsageInfo struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` diff --git a/pkg/types/types_test.go b/pkg/types/types_test.go index 710fff2..0231633 100644 --- a/pkg/types/types_test.go +++ b/pkg/types/types_test.go @@ -48,3 +48,58 @@ func TestCommitPromptContent(t *testing.T) { } } } + +func TestBuildCommitPromptDefault(t *testing.T) { + t.Parallel() + + changes := "diff --git a/main.go b/main.go" + prompt := BuildCommitPrompt(changes, nil) + + if !strings.HasSuffix(prompt, changes) { + t.Fatalf("expected prompt to end with changes, got %q", prompt) + } + + if strings.Contains(prompt, "Additional instructions:") { + t.Fatalf("expected no additional instructions block in default prompt") + } +} + +func TestBuildCommitPromptWithInstructions(t *testing.T) { + t.Parallel() + + changes := "diff --git a/main.go b/main.go" + options := &GenerationOptions{StyleInstruction: "Use a playful tone."} + prompt := BuildCommitPrompt(changes, options) + + if !strings.Contains(prompt, "Additional instructions:") { + t.Fatalf("expected prompt to contain additional instructions block") + } + + if !strings.Contains(prompt, options.StyleInstruction) { + t.Fatalf("expected prompt to include style instruction %q", options.StyleInstruction) + } + + if !strings.HasSuffix(prompt, changes) { + t.Fatalf("expected prompt to end with changes, got %q", prompt) + } +} + +func TestBuildCommitPromptWithAttempt(t *testing.T) { + t.Parallel() + + changes := "diff --git a/main.go b/main.go" + options := &GenerationOptions{Attempt: 3} + prompt := BuildCommitPrompt(changes, options) + + if !strings.Contains(prompt, "Regeneration context:") { + t.Fatalf("expected regeneration context section, got %q", prompt) + } + + if !strings.Contains(prompt, "attempt #3") { + t.Fatalf("expected attempt number to be mentioned, got %q", prompt) + } + + if !strings.HasSuffix(prompt, changes) { + t.Fatalf("expected prompt to end with changes, got %q", prompt) + } +}