From 97470d5d910ea8c54efe082488de04a238f4edb2 Mon Sep 17 00:00:00 2001 From: Muneer320 Date: Tue, 7 Oct 2025 16:02:16 +0530 Subject: [PATCH 1/5] Feat: Add interactive review flow for commit messages --- .gitignore | 3 + README.md | 27 ++- cmd/cli/createMsg.go | 341 ++++++++++++++++++++++++++++++------ go.mod | 1 + go.sum | 2 + internal/chatgpt/chatgpt.go | 4 +- internal/claude/claude.go | 4 +- internal/gemini/gemini.go | 4 +- internal/grok/grok.go | 4 +- internal/groq/groq.go | 4 +- internal/groq/groq_test.go | 45 ++++- internal/ollama/ollama.go | 4 +- pkg/types/options.go | 10 ++ pkg/types/prompt.go | 30 ++++ pkg/types/types_test.go | 55 ++++++ 15 files changed, 463 insertions(+), 75 deletions(-) create mode 100644 pkg/types/options.go diff --git a/.gitignore b/.gitignore index a2d044d..b089932 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ # Test binary, built with `go test -c` *.test +# Test Files +test_* + # Output of the go coverage tool, specifically when used with LiteIDE *.out diff --git a/README.md b/README.md index 3aede34..49be130 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Looking to contribute? Check out: 🔄 **Multiple LLM Support** - Choose between Google Gemini, Grok, Claude, ChatGPT, or Ollama (local) 📝 **Context-Aware** - Analyzes staged and unstaged changes 📋 **Auto-Copy to Clipboard** - Generated messages are automatically copied for instant use +🎛️ **Interactive Review Flow** - Accept, regenerate with new styles, or open the message in your editor before committing 📊 **File Statistics Display** - Visual preview of changed files and line counts 🚀 **Easy to Use** - Simple CLI interface with beautiful terminal UI ⚡ **Fast** - Quick generation of commit messages @@ -62,7 +63,6 @@ You can use **Google Gemini**, **Grok**, **Claude**, **ChatGPT**, or **Ollama** echo 'export PATH=$PATH:/path/to/commit-msg' >> ~/.bashrc # or ~/.zshrc ``` - ### Option 2: Build from Source Requirements: Go 1.23.4 or higher @@ -105,6 +105,7 @@ go run cmd/commit-msg/main.go . ```bash commit llm setup ``` + Screenshot 2025-10-05 172731 Screenshot 2025-10-05 172748 @@ -113,12 +114,12 @@ go run cmd/commit-msg/main.go . ```bash commit llm update ``` + Screenshot 2025-10-05 172814 Screenshot 2025-10-05 172823 ### Example Workflow - ```bash # Make changes to your code echo "console.log('Hello World')" > app.js @@ -137,6 +138,17 @@ commit . # You can now paste it with Ctrl+V (or Cmd+V on macOS) ``` +### Interactive Commit Workflow + +Once the commit message is generated, the CLI now offers a quick review loop: + +- **Accept & copy** – use the message as-is (it still lands on your clipboard automatically) +- **Regenerate** – pick from presets like detailed summaries, casual tone, bug-fix emphasis, or provide custom instructions for the LLM +- **Edit in your editor** – open the message in `$GIT_EDITOR`, `$VISUAL`, `$EDITOR`, or a sensible fallback (`notepad` on Windows, `nano` elsewhere) +- **Exit** – leave without copying anything if the message isn't ready yet + +This makes it easy to tweak the tone, iterate on suggestions, or fine-tune the final wording before you commit. + ### Use Cases - 📝 Generate commit messages for staged changes @@ -163,19 +175,23 @@ commit . ``` **Set LLM as default** + ```bash Select: Set Default ``` **Change API Key** + ```bash Select: Change API Key ``` **Delete LLM** + ```bash Select: Delete ``` + --- ## Getting API Keys @@ -185,36 +201,31 @@ Select: Delete 1. Visit [Google AI Studio](https://makersuite.google.com/app/apikey) 2. Create a new API key - **Grok (X.AI):** 1. Visit [X.AI Console](https://console.x.ai/) 2. Generate an API key - **Groq:** 1. Sign up at [Groq Cloud](https://console.groq.com/) 2. Create an API key - **Claude (Anthropic):** 1. Visit the [Anthropic Console](https://console.anthropic.com/) 2. Create a new API key - **OpenAI (ChatGPT):** 1. Visit [OpenAI Platform](https://platform.openai.com/api-keys) 2. Create a new API key - **Ollama (Local LLM):** 1. Install Ollama: Visit [Ollama.ai](https://ollama.ai/) and follow installation instructions 2. Start Ollama: `ollama serve` -3. Pull a model: `ollama pull llama3` +3. Pull a model: `ollama pull llama3` 4. Set environment variables: ```bash export COMMIT_LLM=ollama diff --git a/cmd/cli/createMsg.go b/cmd/cli/createMsg.go index 7c062b1..f82a0f9 100644 --- a/cmd/cli/createMsg.go +++ b/cmd/cli/createMsg.go @@ -1,7 +1,12 @@ package cmd import ( + "errors" + "fmt" "os" + "os/exec" + "runtime" + "strings" "github.com/atotto/clipboard" "github.com/dfanso/commit-msg/cmd/cli/store" @@ -15,11 +20,11 @@ import ( "github.com/dfanso/commit-msg/internal/ollama" "github.com/dfanso/commit-msg/internal/stats" "github.com/dfanso/commit-msg/pkg/types" + "github.com/google/shlex" "github.com/pterm/pterm" ) func CreateCommitMsg() { - // Validate COMMIT_LLM and required API keys useLLM, err := store.DefaultLLMKey() if err != nil { @@ -43,32 +48,24 @@ func CreateCommitMsg() { os.Exit(1) } - // Create a minimal config for the API config := &types.Config{ GrokAPI: "https://api.x.ai/v1/chat/completions", } - // Create a repo config for the current directory - repoConfig := types.RepoConfig{ - Path: currentDir, - } + repoConfig := types.RepoConfig{Path: currentDir} - // Get file statistics before fetching changes fileStats, err := stats.GetFileStatistics(&repoConfig) if err != nil { pterm.Error.Printf("Failed to get file statistics: %v\n", err) os.Exit(1) } - // Display header pterm.DefaultHeader.WithFullWidth(). WithBackgroundStyle(pterm.NewStyle(pterm.BgDarkGray)). WithTextStyle(pterm.NewStyle(pterm.FgLightWhite)). Println("Commit Message Generator") pterm.Println() - - // Display file statistics with icons display.ShowFileStatistics(fileStats) if fileStats.TotalFiles == 0 { @@ -80,7 +77,6 @@ func CreateCommitMsg() { return } - // Get the changes changes, err := git.GetChanges(&repoConfig) if err != nil { pterm.Error.Printf("Failed to get Git changes: %v\n", err) @@ -97,8 +93,6 @@ func CreateCommitMsg() { } pterm.Println() - - // Show generating spinner spinnerGenerating, err := pterm.DefaultSpinner. WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"). Start("Generating commit message with " + commitLLM.String() + "...") @@ -107,65 +101,308 @@ func CreateCommitMsg() { os.Exit(1) } - var commitMsg string + attempt := 1 + commitMsg, err := generateMessage(commitLLM, config, changes, apiKey, withAttempt(nil, attempt)) + if err != nil { + spinnerGenerating.Fail("Failed to generate commit message") + displayProviderError(commitLLM, err) + os.Exit(1) + } - switch commitLLM { + spinnerGenerating.Success("Commit message generated successfully!") - case types.ProviderGemini: - commitMsg, err = gemini.GenerateCommitMessage(config, changes, apiKey) + currentMessage := strings.TrimSpace(commitMsg) + currentStyleLabel := stylePresets[0].Label + accepted := false + finalMessage := "" - case types.ProviderOpenAI: - commitMsg, err = chatgpt.GenerateCommitMessage(config, changes, apiKey) +interactionLoop: + for { + pterm.Println() + display.ShowCommitMessage(currentMessage) + action, err := promptActionSelection() + if err != nil { + pterm.Error.Printf("Failed to read selection: %v\n", err) + return + } + + switch action { + case actionAcceptOption: + finalMessage = strings.TrimSpace(currentMessage) + if finalMessage == "" { + pterm.Warning.Println("Commit message is empty; please edit or regenerate before accepting.") + continue + } + if err := clipboard.WriteAll(finalMessage); err != nil { + pterm.Warning.Printf("Could not copy to clipboard: %v\n", err) + } else { + pterm.Success.Println("Commit message copied to clipboard!") + } + accepted = true + break interactionLoop + case actionRegenerateOption: + opts, styleLabel, err := promptStyleSelection(currentStyleLabel) + if errors.Is(err, errSelectionCancelled) { + continue + } + if err != nil { + pterm.Error.Printf("Failed to select style: %v\n", err) + continue + } + if styleLabel != "" { + currentStyleLabel = styleLabel + } + nextAttempt := attempt + 1 + generationOpts := withAttempt(opts, nextAttempt) + spinner, err := pterm.DefaultSpinner. + WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"). + Start(fmt.Sprintf("Regenerating commit message (%s)...", currentStyleLabel)) + if err != nil { + pterm.Error.Printf("Failed to start spinner: %v\n", err) + continue + } + updatedMessage, genErr := generateMessage(commitLLM, config, changes, apiKey, generationOpts) + if genErr != nil { + spinner.Fail("Regeneration failed") + displayProviderError(commitLLM, genErr) + continue + } + spinner.Success("Commit message regenerated!") + attempt = nextAttempt + currentMessage = strings.TrimSpace(updatedMessage) + case actionEditOption: + edited, editErr := editCommitMessage(currentMessage) + if editErr != nil { + pterm.Error.Printf("Failed to edit commit message: %v\n", editErr) + continue + } + if strings.TrimSpace(edited) == "" { + pterm.Warning.Println("Edited commit message is empty; keeping previous message.") + continue + } + currentMessage = strings.TrimSpace(edited) + case actionExitOption: + pterm.Info.Println("Exiting without copying commit message.") + return + default: + pterm.Warning.Printf("Unknown selection: %s\n", action) + } + } + + if !accepted { + return + } + + pterm.Println() + display.ShowChangesPreview(fileStats) +} + +type styleOption struct { + Label string + Instruction string +} + +const ( + actionAcceptOption = "Accept and copy commit message" + actionRegenerateOption = "Regenerate with different tone/style" + actionEditOption = "Edit message in editor" + actionExitOption = "Discard and exit" + customStyleOption = "Custom instructions (enter your own)" + styleBackOption = "Back to actions" +) + +var ( + actionOptions = []string{actionAcceptOption, actionRegenerateOption, actionEditOption, actionExitOption} + stylePresets = []styleOption{ + {Label: "Concise conventional (default)", Instruction: ""}, + {Label: "Detailed summary (adds bullet list)", Instruction: "Produce a conventional commit subject line followed by a blank line and bullet points summarizing the key changes."}, + {Label: "Casual tone", Instruction: "Write the commit message in a friendly, conversational tone while still clearly explaining the changes."}, + {Label: "Bug fix emphasis", Instruction: "Highlight the bug being fixed, reference the root cause when possible, and describe the remedy in the body."}, + } + errSelectionCancelled = errors.New("selection cancelled") +) + +func generateMessage(provider types.LLMProvider, config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { + switch provider { + case types.ProviderGemini: + return gemini.GenerateCommitMessage(config, changes, apiKey, opts) + case types.ProviderOpenAI: + return chatgpt.GenerateCommitMessage(config, changes, apiKey, opts) case types.ProviderClaude: - commitMsg, err = claude.GenerateCommitMessage(config, changes, apiKey) + return claude.GenerateCommitMessage(config, changes, apiKey, opts) case types.ProviderGroq: - commitMsg, err = groq.GenerateCommitMessage(config, changes, apiKey) + return groq.GenerateCommitMessage(config, changes, apiKey, opts) case types.ProviderOllama: - model := "llama3:latest" - - commitMsg, err = ollama.GenerateCommitMessage(config, changes, apiKey, model) + url := apiKey + if strings.TrimSpace(url) == "" { + url = os.Getenv("OLLAMA_URL") + if url == "" { + url = "http://localhost:11434/api/generate" + } + } + model := os.Getenv("OLLAMA_MODEL") + if model == "" { + model = "llama3:latest" + } + return ollama.GenerateCommitMessage(config, changes, url, model, opts) default: - commitMsg, err = grok.GenerateCommitMessage(config, changes, apiKey) + return grok.GenerateCommitMessage(config, changes, apiKey, opts) } +} +func promptActionSelection() (string, error) { + return pterm.DefaultInteractiveSelect. + WithOptions(actionOptions). + WithDefaultOption(actionAcceptOption). + Show() +} + +func promptStyleSelection(currentLabel string) (*types.GenerationOptions, string, error) { + options := make([]string, 0, len(stylePresets)+2) + for _, preset := range stylePresets { + options = append(options, preset.Label) + } + options = append(options, customStyleOption, styleBackOption) + + selector := pterm.DefaultInteractiveSelect.WithOptions(options) + if currentLabel != "" { + selector = selector.WithDefaultOption(currentLabel) + } + + choice, err := selector.Show() if err != nil { - spinnerGenerating.Fail("Failed to generate commit message") - switch commitLLM { - case types.ProviderGemini: - pterm.Error.Printf("Gemini API error. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n") - case types.ProviderOpenAI: - pterm.Error.Printf("OpenAI API error. Check your OPENAI_API_KEY environment variable or run: commit llm setup\n") - case types.ProviderClaude: - pterm.Error.Printf("Claude API error. Check your CLAUDE_API_KEY environment variable or run: commit llm setup\n") - case types.ProviderGroq: - pterm.Error.Printf("Groq API error. Check your GROQ_API_KEY environment variable or run: commit llm setup\n") - case types.ProviderGrok: - pterm.Error.Printf("Grok API error. Check your GROK_API_KEY environment variable or run: commit llm setup\n") - default: - pterm.Error.Printf("LLM API error: %v\n", err) + return nil, currentLabel, err + } + + switch choice { + case styleBackOption: + return nil, currentLabel, errSelectionCancelled + case customStyleOption: + text, err := pterm.DefaultInteractiveTextInput. + WithDefaultText("Describe the tone or style you're looking for"). + Show() + if err != nil { + return nil, currentLabel, err + } + text = strings.TrimSpace(text) + if text == "" { + return nil, currentLabel, errSelectionCancelled + } + return &types.GenerationOptions{StyleInstruction: text}, formatCustomStyleLabel(text), nil + default: + for _, preset := range stylePresets { + if choice == preset.Label { + if strings.TrimSpace(preset.Instruction) == "" { + return nil, preset.Label, nil + } + return &types.GenerationOptions{StyleInstruction: preset.Instruction}, preset.Label, nil + } } - os.Exit(1) } - spinnerGenerating.Success("Commit message generated successfully!") + return nil, currentLabel, nil +} - pterm.Println() +func editCommitMessage(initial string) (string, error) { + command, args, err := resolveEditorCommand() + if err != nil { + return "", err + } - // Display the commit message in a styled panel - display.ShowCommitMessage(commitMsg) + tmpFile, err := os.CreateTemp("", "commit-msg-*.txt") + if err != nil { + return "", err + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(strings.TrimSpace(initial) + "\n"); err != nil { + tmpFile.Close() + return "", err + } + + if err := tmpFile.Close(); err != nil { + return "", err + } + + cmdArgs := append(args, tmpFile.Name()) + cmd := exec.Command(command, cmdArgs...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr - // Copy to clipboard - err = clipboard.WriteAll(commitMsg) + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("editor exited with error: %w", err) + } + + content, err := os.ReadFile(tmpFile.Name()) if err != nil { - pterm.Warning.Printf("Could not copy to clipboard: %v\n", err) - } else { - pterm.Success.Println("Commit message copied to clipboard!") + return "", err } - pterm.Println() + return strings.TrimSpace(string(content)), nil +} - // Display changes preview - display.ShowChangesPreview(fileStats) +func resolveEditorCommand() (string, []string, error) { + candidates := []string{ + os.Getenv("GIT_EDITOR"), + os.Getenv("VISUAL"), + os.Getenv("EDITOR"), + } + + for _, candidate := range candidates { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + parts, err := shlex.Split(candidate) + if err != nil { + return "", nil, fmt.Errorf("failed to parse editor command %q: %w", candidate, err) + } + if len(parts) == 0 { + continue + } + return parts[0], parts[1:], nil + } + + if runtime.GOOS == "windows" { + return "notepad", nil, nil + } + + return "nano", nil, nil +} + +func formatCustomStyleLabel(instruction string) string { + trimmed := strings.TrimSpace(instruction) + runes := []rune(trimmed) + if len(runes) > 40 { + return fmt.Sprintf("Custom: %s…", string(runes[:37])) + } + return fmt.Sprintf("Custom: %s", trimmed) +} + +func withAttempt(styleOpts *types.GenerationOptions, attempt int) *types.GenerationOptions { + if styleOpts == nil { + return &types.GenerationOptions{Attempt: attempt} + } + clone := *styleOpts + clone.Attempt = attempt + return &clone +} +func displayProviderError(provider types.LLMProvider, err error) { + switch provider { + case types.ProviderGemini: + pterm.Error.Printf("Gemini API error: %v. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n", err) + case types.ProviderOpenAI: + pterm.Error.Printf("OpenAI API error: %v. Check your OPENAI_API_KEY environment variable or run: commit llm setup\n", err) + case types.ProviderClaude: + pterm.Error.Printf("Claude API error: %v. Check your CLAUDE_API_KEY environment variable or run: commit llm setup\n", err) + case types.ProviderGroq: + pterm.Error.Printf("Groq API error: %v. Check your GROQ_API_KEY environment variable or run: commit llm setup\n", err) + case types.ProviderGrok: + pterm.Error.Printf("Grok API error: %v. Check your GROK_API_KEY environment variable or run: commit llm setup\n", err) + default: + pterm.Error.Printf("LLM API error: %v\n", err) + } } diff --git a/go.mod b/go.mod index 99570d6..5860905 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.24.7 require ( github.com/atotto/clipboard v0.1.4 + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/generative-ai-go v0.19.0 github.com/manifoldco/promptui v0.9.0 github.com/openai/openai-go/v3 v3.0.1 diff --git a/go.sum b/go.sum index fe4be9b..7507149 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= diff --git a/internal/chatgpt/chatgpt.go b/internal/chatgpt/chatgpt.go index ba1cf2e..ed554d5 100644 --- a/internal/chatgpt/chatgpt.go +++ b/internal/chatgpt/chatgpt.go @@ -10,11 +10,11 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) -func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) { +func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { client := openai.NewClient(option.WithAPIKey(apiKey)) - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) resp, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ Messages: []openai.ChatCompletionMessageParamUnion{ diff --git a/internal/claude/claude.go b/internal/claude/claude.go index 1ed0dc6..b8c740c 100644 --- a/internal/claude/claude.go +++ b/internal/claude/claude.go @@ -30,9 +30,9 @@ type ClaudeResponse struct { } `json:"content"` } -func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) { +func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) reqBody := ClaudeRequest{ Model: "claude-3-5-sonnet-20241022", diff --git a/internal/gemini/gemini.go b/internal/gemini/gemini.go index b231868..7670f00 100644 --- a/internal/gemini/gemini.go +++ b/internal/gemini/gemini.go @@ -10,9 +10,9 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) -func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) { +func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { // Prepare request to Gemini API - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) // Create context and client ctx := context.Background() diff --git a/internal/grok/grok.go b/internal/grok/grok.go index a5acc25..d08cc0a 100644 --- a/internal/grok/grok.go +++ b/internal/grok/grok.go @@ -12,9 +12,9 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) -func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) { +func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { // Prepare request to X.AI (Grok) API - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) request := types.GrokRequest{ diff --git a/internal/groq/groq.go b/internal/groq/groq.go index cee46bd..b7a1be4 100644 --- a/internal/groq/groq.go +++ b/internal/groq/groq.go @@ -43,12 +43,12 @@ var ( ) // GenerateCommitMessage calls Groq's OpenAI-compatible chat completions API. -func GenerateCommitMessage(_ *types.Config, changes string, apiKey string) (string, error) { +func GenerateCommitMessage(_ *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { if changes == "" { return "", fmt.Errorf("no changes provided for commit message generation") } - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) model := os.Getenv("GROQ_MODEL") if model == "" { diff --git a/internal/groq/groq_test.go b/internal/groq/groq_test.go index 754e4de..c1e2557 100644 --- a/internal/groq/groq_test.go +++ b/internal/groq/groq_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" "github.com/dfanso/commit-msg/pkg/types" @@ -73,7 +74,7 @@ func TestGenerateCommitMessageSuccess(t *testing.T) { t.Fatalf("failed to write response: %v", err) } }, func() { - msg, err := GenerateCommitMessage(&types.Config{}, "diff", "test-key") + msg, err := GenerateCommitMessage(&types.Config{}, "diff", "test-key", nil) if err != nil { t.Fatalf("GenerateCommitMessage returned error: %v", err) } @@ -89,7 +90,7 @@ func TestGenerateCommitMessageNonOK(t *testing.T) { withTestServer(t, func(w http.ResponseWriter, r *http.Request) { http.Error(w, `{"error":"bad things"}`, http.StatusBadGateway) }, func() { - _, err := GenerateCommitMessage(&types.Config{}, "changes", "key") + _, err := GenerateCommitMessage(&types.Config{}, "changes", "key", nil) if err == nil { t.Fatal("expected error but got nil") } @@ -100,7 +101,45 @@ func TestGenerateCommitMessageEmptyChanges(t *testing.T) { t.Setenv("GROQ_MODEL", "") t.Setenv("GROQ_API_URL", "") - if _, err := GenerateCommitMessage(&types.Config{}, "", "key"); err == nil { + if _, err := GenerateCommitMessage(&types.Config{}, "", "key", nil); err == nil { t.Fatal("expected error for empty changes") } } + +func TestGenerateCommitMessageIncludesStyleInstruction(t *testing.T) { + recorded := "" + + withTestServer(t, func(w http.ResponseWriter, r *http.Request) { + var payload capturedRequest + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("failed to decode request: %v", err) + } + recorded = payload.Messages[1].Content + + resp := chatResponse{ + Choices: []chatChoice{ + {Message: chatMessage{Role: "assistant", Content: "feat: add style support"}}, + }, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("failed to write response: %v", err) + } + }, func() { + opts := &types.GenerationOptions{StyleInstruction: "Use a casual tone.", Attempt: 2} + if _, err := GenerateCommitMessage(&types.Config{}, "diff", "key", opts); err != nil { + t.Fatalf("GenerateCommitMessage returned error: %v", err) + } + }) + + if !strings.Contains(recorded, "Additional instructions:") { + t.Fatalf("expected request payload to contain additional instructions, got: %q", recorded) + } + if !strings.Contains(recorded, "Use a casual tone.") { + t.Fatalf("expected request payload to contain custom instruction, got: %q", recorded) + } + if !strings.Contains(recorded, "Regeneration context:") { + t.Fatalf("expected request payload to contain regeneration context, got: %q", recorded) + } +} diff --git a/internal/ollama/ollama.go b/internal/ollama/ollama.go index 2094485..1765f50 100644 --- a/internal/ollama/ollama.go +++ b/internal/ollama/ollama.go @@ -20,14 +20,14 @@ type OllamaResponse struct { Done bool `json:"done"` } -func GenerateCommitMessage(_ *types.Config, changes string, url string, model string) (string, error) { +func GenerateCommitMessage(_ *types.Config, changes string, url string, model string, opts *types.GenerationOptions) (string, error) { // Use llama3:latest as the default model if model == "" { model = "llama3:latest" } // Preparing the prompt - prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + prompt := types.BuildCommitPrompt(changes, opts) // Generating the request body - add stream: false for non-streaming response reqBody := map[string]interface{}{ diff --git a/pkg/types/options.go b/pkg/types/options.go new file mode 100644 index 0000000..d7ed46d --- /dev/null +++ b/pkg/types/options.go @@ -0,0 +1,10 @@ +package types + +// GenerationOptions controls how commit messages should be produced by LLM providers. +type GenerationOptions struct { + // StyleInstruction contains optional tone/style guidance appended to the base prompt. + StyleInstruction string + // Attempt records the 1-indexed attempt number for this generation request. + // Attempt > 1 signals that the LLM should provide an alternative output. + Attempt int +} diff --git a/pkg/types/prompt.go b/pkg/types/prompt.go index 5b9a2f5..75f546e 100644 --- a/pkg/types/prompt.go +++ b/pkg/types/prompt.go @@ -1,5 +1,10 @@ package types +import ( + "fmt" + "strings" +) + var CommitPrompt = `I need a concise git commit message based on the following changes from my Git repository. Please generate a commit message that: 1. Starts with a verb in the present tense (e.g., "Add", "Fix", "Update", "Feat", "Refactor", etc.) @@ -17,3 +22,28 @@ here is a sample commit msgs: - variable. Also updates dependencies and README.' Here are the changes: ` + +// BuildCommitPrompt constructs the prompt that will be sent to the LLM, applying +// any optional tone/style instructions before appending the repository changes. +func BuildCommitPrompt(changes string, opts *GenerationOptions) string { + var builder strings.Builder + builder.WriteString(CommitPrompt) + + if opts != nil { + if opts.Attempt > 1 { + builder.WriteString("\n\nRegeneration context:\n") + builder.WriteString(fmt.Sprintf("- This is attempt #%d.\n", opts.Attempt)) + builder.WriteString("- Provide a commit message that is meaningfully different from earlier attempts.\n") + } + + if strings.TrimSpace(opts.StyleInstruction) != "" { + builder.WriteString("\n\nAdditional instructions:\n") + builder.WriteString(strings.TrimSpace(opts.StyleInstruction)) + } + } + + builder.WriteString("\n\n") + builder.WriteString(changes) + + return builder.String() +} diff --git a/pkg/types/types_test.go b/pkg/types/types_test.go index 710fff2..0231633 100644 --- a/pkg/types/types_test.go +++ b/pkg/types/types_test.go @@ -48,3 +48,58 @@ func TestCommitPromptContent(t *testing.T) { } } } + +func TestBuildCommitPromptDefault(t *testing.T) { + t.Parallel() + + changes := "diff --git a/main.go b/main.go" + prompt := BuildCommitPrompt(changes, nil) + + if !strings.HasSuffix(prompt, changes) { + t.Fatalf("expected prompt to end with changes, got %q", prompt) + } + + if strings.Contains(prompt, "Additional instructions:") { + t.Fatalf("expected no additional instructions block in default prompt") + } +} + +func TestBuildCommitPromptWithInstructions(t *testing.T) { + t.Parallel() + + changes := "diff --git a/main.go b/main.go" + options := &GenerationOptions{StyleInstruction: "Use a playful tone."} + prompt := BuildCommitPrompt(changes, options) + + if !strings.Contains(prompt, "Additional instructions:") { + t.Fatalf("expected prompt to contain additional instructions block") + } + + if !strings.Contains(prompt, options.StyleInstruction) { + t.Fatalf("expected prompt to include style instruction %q", options.StyleInstruction) + } + + if !strings.HasSuffix(prompt, changes) { + t.Fatalf("expected prompt to end with changes, got %q", prompt) + } +} + +func TestBuildCommitPromptWithAttempt(t *testing.T) { + t.Parallel() + + changes := "diff --git a/main.go b/main.go" + options := &GenerationOptions{Attempt: 3} + prompt := BuildCommitPrompt(changes, options) + + if !strings.Contains(prompt, "Regeneration context:") { + t.Fatalf("expected regeneration context section, got %q", prompt) + } + + if !strings.Contains(prompt, "attempt #3") { + t.Fatalf("expected attempt number to be mentioned, got %q", prompt) + } + + if !strings.HasSuffix(prompt, changes) { + t.Fatalf("expected prompt to end with changes, got %q", prompt) + } +} From ec3d2bc9b29d3a2180de9cc8d79cb85622f637f7 Mon Sep 17 00:00:00 2001 From: Muneer320 Date: Tue, 7 Oct 2025 20:23:14 +0530 Subject: [PATCH 2/5] Improved Docstring Coverage --- cmd/cli/createMsg.go | 2 ++ cmd/cli/llmSetup.go | 4 ++++ cmd/cli/store/store.go | 9 +++++++++ internal/chatgpt/chatgpt.go | 2 ++ internal/claude/claude.go | 4 ++++ internal/gemini/gemini.go | 2 ++ internal/grok/grok.go | 2 ++ internal/ollama/ollama.go | 4 ++++ pkg/types/prompt.go | 2 ++ pkg/types/types.go | 10 ++++++++++ 10 files changed, 41 insertions(+) diff --git a/cmd/cli/createMsg.go b/cmd/cli/createMsg.go index f82a0f9..819ce85 100644 --- a/cmd/cli/createMsg.go +++ b/cmd/cli/createMsg.go @@ -24,6 +24,8 @@ import ( "github.com/pterm/pterm" ) +// CreateCommitMsg launches the interactive flow for reviewing, regenerating, +// editing, and accepting AI-generated commit messages in the current repo. func CreateCommitMsg() { // Validate COMMIT_LLM and required API keys useLLM, err := store.DefaultLLMKey() diff --git a/cmd/cli/llmSetup.go b/cmd/cli/llmSetup.go index f551c46..4ef57e3 100644 --- a/cmd/cli/llmSetup.go +++ b/cmd/cli/llmSetup.go @@ -9,6 +9,8 @@ import ( "github.com/manifoldco/promptui" ) +// SetupLLM walks the user through selecting an LLM provider and storing the +// corresponding API key or endpoint configuration. func SetupLLM() error { providers := types.GetSupportedProviderStrings() @@ -67,6 +69,8 @@ func SetupLLM() error { return nil } +// UpdateLLM lets the user switch defaults, rotate API keys, or delete stored +// LLM provider configurations. func UpdateLLM() error { SavedModels, err := store.ListSavedModels() diff --git a/cmd/cli/store/store.go b/cmd/cli/store/store.go index b18c1c2..07d552a 100644 --- a/cmd/cli/store/store.go +++ b/cmd/cli/store/store.go @@ -1,3 +1,4 @@ +// Package store persists user-selected LLM providers and credentials. package store import ( @@ -11,16 +12,19 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) +// LLMProvider represents a single stored LLM provider and its credential. type LLMProvider struct { LLM types.LLMProvider `json:"model"` APIKey string `json:"api_key"` } +// Config describes the on-disk structure for all saved LLM providers. type Config struct { Default types.LLMProvider `json:"default"` LLMProviders []LLMProvider `json:"models"` } +// Save persists or updates an LLM provider entry, marking it as the default. func Save(LLMConfig LLMProvider) error { cfg := Config{ @@ -139,6 +143,7 @@ func getConfigPath() (string, error) { } +// DefaultLLMKey returns the currently selected default LLM provider, if any. func DefaultLLMKey() (*LLMProvider, error) { var cfg Config @@ -179,6 +184,7 @@ func DefaultLLMKey() (*LLMProvider, error) { return nil, errors.New("not found default model in config") } +// ListSavedModels loads all persisted LLM provider configurations. func ListSavedModels() (*Config, error) { var cfg Config @@ -211,6 +217,7 @@ func ListSavedModels() (*Config, error) { } +// ChangeDefault updates the default LLM provider selection in the config. func ChangeDefault(Model types.LLMProvider) error { var cfg Config @@ -247,6 +254,7 @@ func ChangeDefault(Model types.LLMProvider) error { return os.WriteFile(configPath, data, 0600) } +// DeleteModel removes the specified provider from the saved configuration. func DeleteModel(Model types.LLMProvider) error { var cfg Config @@ -300,6 +308,7 @@ func DeleteModel(Model types.LLMProvider) error { } } +// UpdateAPIKey rotates the credential for an existing provider entry. func UpdateAPIKey(Model types.LLMProvider, APIKey string) error { var cfg Config diff --git a/internal/chatgpt/chatgpt.go b/internal/chatgpt/chatgpt.go index ed554d5..714c046 100644 --- a/internal/chatgpt/chatgpt.go +++ b/internal/chatgpt/chatgpt.go @@ -10,6 +10,8 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) +// GenerateCommitMessage calls OpenAI's chat completions API to turn the provided +// repository changes into a polished git commit message. func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { client := openai.NewClient(option.WithAPIKey(apiKey)) diff --git a/internal/claude/claude.go b/internal/claude/claude.go index b8c740c..c85da45 100644 --- a/internal/claude/claude.go +++ b/internal/claude/claude.go @@ -10,17 +10,20 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) +// ClaudeRequest describes the payload sent to Anthropic's Claude messages API. type ClaudeRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` MaxTokens int `json:"max_tokens"` } +// Message represents a single role/content pair exchanged with Claude. type Message struct { Role string `json:"role"` Content string `json:"content"` } +// ClaudeResponse captures the subset of fields used from Anthropic responses. type ClaudeResponse struct { ID string `json:"id"` Type string `json:"type"` @@ -30,6 +33,7 @@ type ClaudeResponse struct { } `json:"content"` } +// GenerateCommitMessage produces a commit summary using Anthropic's Claude API. func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { prompt := types.BuildCommitPrompt(changes, opts) diff --git a/internal/gemini/gemini.go b/internal/gemini/gemini.go index 7670f00..0606d49 100644 --- a/internal/gemini/gemini.go +++ b/internal/gemini/gemini.go @@ -10,6 +10,8 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) +// GenerateCommitMessage asks Google Gemini to author a commit message for the +// supplied repository changes and optional style instructions. func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { // Prepare request to Gemini API prompt := types.BuildCommitPrompt(changes, opts) diff --git a/internal/grok/grok.go b/internal/grok/grok.go index d08cc0a..32116d2 100644 --- a/internal/grok/grok.go +++ b/internal/grok/grok.go @@ -12,6 +12,8 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) +// GenerateCommitMessage calls X.AI's Grok API to create a commit message from +// the provided Git diff and generation options. func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { // Prepare request to X.AI (Grok) API prompt := types.BuildCommitPrompt(changes, opts) diff --git a/internal/ollama/ollama.go b/internal/ollama/ollama.go index 1765f50..cb4f26e 100644 --- a/internal/ollama/ollama.go +++ b/internal/ollama/ollama.go @@ -10,16 +10,20 @@ import ( "github.com/dfanso/commit-msg/pkg/types" ) +// OllamaRequest captures the prompt payload sent to an Ollama HTTP endpoint. type OllamaRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` } +// OllamaResponse represents the non-streaming response from Ollama. type OllamaResponse struct { Response string `json:"response"` Done bool `json:"done"` } +// GenerateCommitMessage uses a locally hosted Ollama model to draft a commit +// message from repository changes and optional style guidance. func GenerateCommitMessage(_ *types.Config, changes string, url string, model string, opts *types.GenerationOptions) (string, error) { // Use llama3:latest as the default model if model == "" { diff --git a/pkg/types/prompt.go b/pkg/types/prompt.go index 75f546e..cd55192 100644 --- a/pkg/types/prompt.go +++ b/pkg/types/prompt.go @@ -5,6 +5,8 @@ import ( "strings" ) +// CommitPrompt is the base instruction template sent to LLM providers before +// appending repository changes and optional style guidance. var CommitPrompt = `I need a concise git commit message based on the following changes from my Git repository. Please generate a commit message that: 1. Starts with a verb in the present tense (e.g., "Add", "Fix", "Update", "Feat", "Refactor", etc.) diff --git a/pkg/types/types.go b/pkg/types/types.go index 2e6f848..7bfdc90 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -19,6 +19,8 @@ func (p LLMProvider) IsValid() bool { switch p { case ProviderOpenAI, ProviderClaude, ProviderGemini, ProviderGrok, ProviderGroq, ProviderOllama: return true + // LLMProvider identifies the large language model backend used to author + // commit messages. default: return false } @@ -30,11 +32,14 @@ func GetSupportedProviders() []LLMProvider { ProviderClaude, ProviderGemini, ProviderGrok, + // String returns the string form of the provider identifier. ProviderGroq, ProviderOllama, } } +// IsValid reports whether the provider is part of the supported set. + func GetSupportedProviderStrings() []string { providers := GetSupportedProviders() strings := make([]string, len(providers)) @@ -44,6 +49,8 @@ func GetSupportedProviderStrings() []string { return strings } +// GetSupportedProviders returns all available provider enums. + func ParseLLMProvider(s string) (LLMProvider, bool) { provider := LLMProvider(s) return provider, provider.IsValid() @@ -55,6 +62,8 @@ type Config struct { Repos map[string]RepoConfig `json:"repos"` } +// GetSupportedProviderStrings returns the human-friendly names for providers. + // Repository configuration type RepoConfig struct { Path string `json:"path"` @@ -63,6 +72,7 @@ type RepoConfig struct { // Grok/X.AI API request structure type GrokRequest struct { + // ParseLLMProvider converts a string into an LLMProvider enum when supported. Messages []Message `json:"messages"` Model string `json:"model"` Stream bool `json:"stream"` From 5f65d558d17b4d8c763d6c32df30351434b1384d Mon Sep 17 00:00:00 2001 From: Muneer320 Date: Tue, 7 Oct 2025 20:42:20 +0530 Subject: [PATCH 3/5] Fix: retain custom styles during regeneration --- README.md | 4 ++-- cmd/cli/createMsg.go | 33 +++++++++++++++++++++++++-------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 49be130..9554dbc 100644 --- a/README.md +++ b/README.md @@ -225,11 +225,11 @@ Select: Delete 1. Install Ollama: Visit [Ollama.ai](https://ollama.ai/) and follow installation instructions 2. Start Ollama: `ollama serve` -3. Pull a model: `ollama pull llama3` +3. Pull a model: `ollama pull llama3.1` 4. Set environment variables: ```bash export COMMIT_LLM=ollama - export OLLAMA_MODEL=llama3 # llama3 by default + export OLLAMA_MODEL=llama3.1 # llama3.1 by default ``` --- diff --git a/cmd/cli/createMsg.go b/cmd/cli/createMsg.go index 819ce85..6d889b0 100644 --- a/cmd/cli/createMsg.go +++ b/cmd/cli/createMsg.go @@ -115,6 +115,7 @@ func CreateCommitMsg() { currentMessage := strings.TrimSpace(commitMsg) currentStyleLabel := stylePresets[0].Label + var currentStyleOpts *types.GenerationOptions accepted := false finalMessage := "" @@ -144,7 +145,7 @@ interactionLoop: accepted = true break interactionLoop case actionRegenerateOption: - opts, styleLabel, err := promptStyleSelection(currentStyleLabel) + opts, styleLabel, err := promptStyleSelection(currentStyleLabel, currentStyleOpts) if errors.Is(err, errSelectionCancelled) { continue } @@ -155,8 +156,9 @@ interactionLoop: if styleLabel != "" { currentStyleLabel = styleLabel } + currentStyleOpts = opts nextAttempt := attempt + 1 - generationOpts := withAttempt(opts, nextAttempt) + generationOpts := withAttempt(currentStyleOpts, nextAttempt) spinner, err := pterm.DefaultSpinner. WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"). Start(fmt.Sprintf("Regenerating commit message (%s)...", currentStyleLabel)) @@ -260,10 +262,17 @@ func promptActionSelection() (string, error) { Show() } -func promptStyleSelection(currentLabel string) (*types.GenerationOptions, string, error) { - options := make([]string, 0, len(stylePresets)+2) +func promptStyleSelection(currentLabel string, currentOpts *types.GenerationOptions) (*types.GenerationOptions, string, error) { + options := make([]string, 0, len(stylePresets)+3) + foundCurrent := false for _, preset := range stylePresets { options = append(options, preset.Label) + if preset.Label == currentLabel { + foundCurrent = true + } + } + if currentOpts != nil && currentLabel != "" && !foundCurrent { + options = append(options, currentLabel) } options = append(options, customStyleOption, styleBackOption) @@ -274,22 +283,22 @@ func promptStyleSelection(currentLabel string) (*types.GenerationOptions, string choice, err := selector.Show() if err != nil { - return nil, currentLabel, err + return currentOpts, currentLabel, err } switch choice { case styleBackOption: - return nil, currentLabel, errSelectionCancelled + return currentOpts, currentLabel, errSelectionCancelled case customStyleOption: text, err := pterm.DefaultInteractiveTextInput. WithDefaultText("Describe the tone or style you're looking for"). Show() if err != nil { - return nil, currentLabel, err + return currentOpts, currentLabel, err } text = strings.TrimSpace(text) if text == "" { - return nil, currentLabel, errSelectionCancelled + return currentOpts, currentLabel, errSelectionCancelled } return &types.GenerationOptions{StyleInstruction: text}, formatCustomStyleLabel(text), nil default: @@ -301,8 +310,16 @@ func promptStyleSelection(currentLabel string) (*types.GenerationOptions, string return &types.GenerationOptions{StyleInstruction: preset.Instruction}, preset.Label, nil } } + if currentOpts != nil && choice == currentLabel { + clone := *currentOpts + return &clone, currentLabel, nil + } } + if currentOpts != nil && currentLabel != "" { + clone := *currentOpts + return &clone, currentLabel, nil + } return nil, currentLabel, nil } From 535b3bc8d1811581ab5409fcd370b4e01c20eac3 Mon Sep 17 00:00:00 2001 From: Muneer320 Date: Tue, 7 Oct 2025 21:00:49 +0530 Subject: [PATCH 4/5] Refactor: reuse shared message type in claude client --- cmd/cli/store/store.go | 17 +++++++++++++++ internal/claude/claude.go | 14 ++++-------- internal/grok/grok.go | 46 ++++++++++++++++++++++++--------------- pkg/types/types.go | 28 ++++++++++++------------ 4 files changed, 63 insertions(+), 42 deletions(-) diff --git a/cmd/cli/store/store.go b/cmd/cli/store/store.go index 07d552a..efa2b98 100644 --- a/cmd/cli/store/store.go +++ b/cmd/cli/store/store.go @@ -244,6 +244,17 @@ func ChangeDefault(Model types.LLMProvider) error { } } + found := false + for _, p := range cfg.LLMProviders { + if p.LLM == Model { + found = true + break + } + } + if !found { + return fmt.Errorf("cannot set default to %s: no saved entry", Model.String()) + } + cfg.Default = Model data, err = json.MarshalIndent(cfg, "", " ") @@ -335,12 +346,18 @@ func UpdateAPIKey(Model types.LLMProvider, APIKey string) error { } } + updated := false for i, p := range cfg.LLMProviders { if p.LLM == Model { cfg.LLMProviders[i].APIKey = APIKey + updated = true } } + if !updated { + return fmt.Errorf("no saved entry for %s to update", Model.String()) + } + data, err = json.MarshalIndent(cfg, "", " ") if err != nil { return err diff --git a/internal/claude/claude.go b/internal/claude/claude.go index c85da45..5a0a02b 100644 --- a/internal/claude/claude.go +++ b/internal/claude/claude.go @@ -12,15 +12,9 @@ import ( // ClaudeRequest describes the payload sent to Anthropic's Claude messages API. type ClaudeRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens int `json:"max_tokens"` -} - -// Message represents a single role/content pair exchanged with Claude. -type Message struct { - Role string `json:"role"` - Content string `json:"content"` + Model string `json:"model"` + Messages []types.Message `json:"messages"` + MaxTokens int `json:"max_tokens"` } // ClaudeResponse captures the subset of fields used from Anthropic responses. @@ -41,7 +35,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string, reqBody := ClaudeRequest{ Model: "claude-3-5-sonnet-20241022", MaxTokens: 200, - Messages: []Message{ + Messages: []types.Message{ { Role: "user", Content: prompt, diff --git a/internal/grok/grok.go b/internal/grok/grok.go index 32116d2..65ebd4e 100644 --- a/internal/grok/grok.go +++ b/internal/grok/grok.go @@ -5,13 +5,38 @@ import ( "crypto/tls" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" + "sync" "time" "github.com/dfanso/commit-msg/pkg/types" ) +var ( + grokClientOnce sync.Once + grokClient *http.Client +) + +func getHTTPClient() *http.Client { + grokClientOnce.Do(func() { + transport := &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + DisableCompression: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: false, + }, + } + grokClient = &http.Client{ + Timeout: 30 * time.Second, + Transport: transport, + } + }) + return grokClient +} + // GenerateCommitMessage calls X.AI's Grok API to create a commit message from // the provided Git diff and generation options. func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { @@ -46,22 +71,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string, req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) - // Configure HTTP client with improved TLS settings - transport := &http.Transport{ - TLSHandshakeTimeout: 10 * time.Second, - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Second, - DisableCompression: true, - // Add TLS config to handle server name mismatch - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: false, // Keep this false for security - }, - } - - client := &http.Client{ - Timeout: 30 * time.Second, - Transport: transport, - } + client := getHTTPClient() resp, err := client.Do(req) if err != nil { return "", err @@ -70,7 +80,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string, // Check response status if resp.StatusCode != http.StatusOK { - bodyBytes, _ := ioutil.ReadAll(resp.Body) + bodyBytes, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) } diff --git a/pkg/types/types.go b/pkg/types/types.go index 7bfdc90..8bf7233 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -1,5 +1,7 @@ package types +// LLMProvider identifies the large language model backend used to author +// commit messages. type LLMProvider string const ( @@ -11,35 +13,34 @@ const ( ProviderOllama LLMProvider = "Ollama" ) +// String returns the provider identifier as a plain string. func (p LLMProvider) String() string { return string(p) } +// IsValid reports whether the provider is part of the supported set. func (p LLMProvider) IsValid() bool { switch p { case ProviderOpenAI, ProviderClaude, ProviderGemini, ProviderGrok, ProviderGroq, ProviderOllama: return true - // LLMProvider identifies the large language model backend used to author - // commit messages. default: return false } } +// GetSupportedProviders returns all available provider enums. func GetSupportedProviders() []LLMProvider { return []LLMProvider{ ProviderOpenAI, ProviderClaude, ProviderGemini, ProviderGrok, - // String returns the string form of the provider identifier. ProviderGroq, ProviderOllama, } } -// IsValid reports whether the provider is part of the supported set. - +// GetSupportedProviderStrings returns the human-friendly names for providers. func GetSupportedProviderStrings() []string { providers := GetSupportedProviders() strings := make([]string, len(providers)) @@ -49,42 +50,39 @@ func GetSupportedProviderStrings() []string { return strings } -// GetSupportedProviders returns all available provider enums. - +// ParseLLMProvider converts a string into an LLMProvider enum when supported. func ParseLLMProvider(s string) (LLMProvider, bool) { provider := LLMProvider(s) return provider, provider.IsValid() } -// Configuration structure +// Config stores CLI-level configuration including named repositories. type Config struct { GrokAPI string `json:"grok_api"` Repos map[string]RepoConfig `json:"repos"` } -// GetSupportedProviderStrings returns the human-friendly names for providers. - -// Repository configuration +// RepoConfig tracks metadata for a configured Git repository. type RepoConfig struct { Path string `json:"path"` LastRun string `json:"last_run"` } -// Grok/X.AI API request structure +// GrokRequest represents a chat completion request sent to X.AI's API. type GrokRequest struct { - // ParseLLMProvider converts a string into an LLMProvider enum when supported. Messages []Message `json:"messages"` Model string `json:"model"` Stream bool `json:"stream"` Temperature float64 `json:"temperature"` } +// Message captures the role/content pairs exchanged with Grok. type Message struct { Role string `json:"role"` Content string `json:"content"` } -// Grok/X.AI API response structure +// GrokResponse contains the relevant fields parsed from X.AI responses. type GrokResponse struct { Message Message `json:"message,omitempty"` Choices []Choice `json:"choices,omitempty"` @@ -95,12 +93,14 @@ type GrokResponse struct { Usage UsageInfo `json:"usage,omitempty"` } +// Choice details a single response option returned by Grok. type Choice struct { Message Message `json:"message"` Index int `json:"index"` FinishReason string `json:"finish_reason"` } +// UsageInfo reports token usage statistics from Grok responses. type UsageInfo struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` From 9bcb69616856b47523930a5c5386a2976994ab84 Mon Sep 17 00:00:00 2001 From: Muneer320 Date: Tue, 7 Oct 2025 21:01:19 +0530 Subject: [PATCH 5/5] Fix: align ollama defaults with docs --- README.md | 6 +++--- cmd/cli/createMsg.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 9554dbc..b26ec5c 100644 --- a/README.md +++ b/README.md @@ -174,19 +174,19 @@ This makes it easy to tweak the tone, iterate on suggestions, or fine-tune the f commit llm update ``` -**Set LLM as default** +### Set LLM as default ```bash Select: Set Default ``` -**Change API Key** +### Change API Key ```bash Select: Change API Key ``` -**Delete LLM** +### Delete LLM ```bash Select: Delete diff --git a/cmd/cli/createMsg.go b/cmd/cli/createMsg.go index 6d889b0..5c08578 100644 --- a/cmd/cli/createMsg.go +++ b/cmd/cli/createMsg.go @@ -247,7 +247,7 @@ func generateMessage(provider types.LLMProvider, config *types.Config, changes s } model := os.Getenv("OLLAMA_MODEL") if model == "" { - model = "llama3:latest" + model = "llama3.1" } return ollama.GenerateCommitMessage(config, changes, url, model, opts) default: