diff --git a/.gitignore b/.gitignore
index a2d044d..b089932 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,9 @@
# Test binary, built with `go test -c`
*.test
+# Test Files
+test_*
+
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
diff --git a/README.md b/README.md
index 3aede34..b26ec5c 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,7 @@ Looking to contribute? Check out:
š **Multiple LLM Support** - Choose between Google Gemini, Grok, Claude, ChatGPT, or Ollama (local)
š **Context-Aware** - Analyzes staged and unstaged changes
š **Auto-Copy to Clipboard** - Generated messages are automatically copied for instant use
+šļø **Interactive Review Flow** - Accept, regenerate with new styles, or open the message in your editor before committing
š **File Statistics Display** - Visual preview of changed files and line counts
š **Easy to Use** - Simple CLI interface with beautiful terminal UI
ā” **Fast** - Quick generation of commit messages
@@ -62,7 +63,6 @@ You can use **Google Gemini**, **Grok**, **Claude**, **ChatGPT**, or **Ollama**
echo 'export PATH=$PATH:/path/to/commit-msg' >> ~/.bashrc # or ~/.zshrc
```
-
### Option 2: Build from Source
Requirements: Go 1.23.4 or higher
@@ -105,6 +105,7 @@ go run cmd/commit-msg/main.go .
```bash
commit llm setup
```
+
@@ -113,12 +114,12 @@ go run cmd/commit-msg/main.go .
```bash
commit llm update
```
+
### Example Workflow
-
```bash
# Make changes to your code
echo "console.log('Hello World')" > app.js
@@ -137,6 +138,17 @@ commit .
# You can now paste it with Ctrl+V (or Cmd+V on macOS)
```
+### Interactive Commit Workflow
+
+Once the commit message is generated, the CLI now offers a quick review loop:
+
+- **Accept & copy** — use the message as-is (it still lands on your clipboard automatically)
+- **Regenerate** — pick from presets like detailed summaries, casual tone, bug-fix emphasis, or provide custom instructions for the LLM
+- **Edit in your editor** — open the message in `$GIT_EDITOR`, `$VISUAL`, `$EDITOR`, or a sensible fallback (`notepad` on Windows, `nano` elsewhere)
+- **Exit** — leave without copying anything if the message isn't ready yet
+
+This makes it easy to tweak the tone, iterate on suggestions, or fine-tune the final wording before you commit.
+
### Use Cases
- š Generate commit messages for staged changes
@@ -162,20 +174,24 @@ commit .
commit llm update
```
-**Set LLM as default**
+### Set LLM as default
+
```bash
Select: Set Default
```
-**Change API Key**
+### Change API Key
+
```bash
Select: Change API Key
```
-**Delete LLM**
+### Delete LLM
+
```bash
Select: Delete
```
+
---
## Getting API Keys
@@ -185,40 +201,35 @@ Select: Delete
1. Visit [Google AI Studio](https://makersuite.google.com/app/apikey)
2. Create a new API key
-
**Grok (X.AI):**
1. Visit [X.AI Console](https://console.x.ai/)
2. Generate an API key
-
**Groq:**
1. Sign up at [Groq Cloud](https://console.groq.com/)
2. Create an API key
-
**Claude (Anthropic):**
1. Visit the [Anthropic Console](https://console.anthropic.com/)
2. Create a new API key
-
**OpenAI (ChatGPT):**
1. Visit [OpenAI Platform](https://platform.openai.com/api-keys)
2. Create a new API key
-
**Ollama (Local LLM):**
1. Install Ollama: Visit [Ollama.ai](https://ollama.ai/) and follow installation instructions
2. Start Ollama: `ollama serve`
-3. Pull a model: `ollama pull llama3`
+3. Pull a model: `ollama pull llama3.1`
4. Set environment variables:
```bash
export COMMIT_LLM=ollama
- export OLLAMA_MODEL=llama3 # llama3 by default
+ export OLLAMA_MODEL=llama3.1 # llama3.1 by default
```
---
diff --git a/cmd/cli/createMsg.go b/cmd/cli/createMsg.go
index 7c062b1..5c08578 100644
--- a/cmd/cli/createMsg.go
+++ b/cmd/cli/createMsg.go
@@ -1,7 +1,12 @@
package cmd
import (
+ "errors"
+ "fmt"
"os"
+ "os/exec"
+ "runtime"
+ "strings"
"github.com/atotto/clipboard"
"github.com/dfanso/commit-msg/cmd/cli/store"
@@ -15,11 +20,13 @@ import (
"github.com/dfanso/commit-msg/internal/ollama"
"github.com/dfanso/commit-msg/internal/stats"
"github.com/dfanso/commit-msg/pkg/types"
+ "github.com/google/shlex"
"github.com/pterm/pterm"
)
+// CreateCommitMsg launches the interactive flow for reviewing, regenerating,
+// editing, and accepting AI-generated commit messages in the current repo.
func CreateCommitMsg() {
-
// Validate COMMIT_LLM and required API keys
useLLM, err := store.DefaultLLMKey()
if err != nil {
@@ -43,32 +50,24 @@ func CreateCommitMsg() {
os.Exit(1)
}
- // Create a minimal config for the API
config := &types.Config{
GrokAPI: "https://api.x.ai/v1/chat/completions",
}
- // Create a repo config for the current directory
- repoConfig := types.RepoConfig{
- Path: currentDir,
- }
+ repoConfig := types.RepoConfig{Path: currentDir}
- // Get file statistics before fetching changes
fileStats, err := stats.GetFileStatistics(&repoConfig)
if err != nil {
pterm.Error.Printf("Failed to get file statistics: %v\n", err)
os.Exit(1)
}
- // Display header
pterm.DefaultHeader.WithFullWidth().
WithBackgroundStyle(pterm.NewStyle(pterm.BgDarkGray)).
WithTextStyle(pterm.NewStyle(pterm.FgLightWhite)).
Println("Commit Message Generator")
pterm.Println()
-
- // Display file statistics with icons
display.ShowFileStatistics(fileStats)
if fileStats.TotalFiles == 0 {
@@ -80,7 +79,6 @@ func CreateCommitMsg() {
return
}
- // Get the changes
changes, err := git.GetChanges(&repoConfig)
if err != nil {
pterm.Error.Printf("Failed to get Git changes: %v\n", err)
@@ -97,8 +95,6 @@ func CreateCommitMsg() {
}
pterm.Println()
-
- // Show generating spinner
spinnerGenerating, err := pterm.DefaultSpinner.
WithSequence("ā ", "ā ", "ā ¹", "ā ø", "ā ¼", "ā “", "ā ¦", "ā §", "ā ", "ā ").
Start("Generating commit message with " + commitLLM.String() + "...")
@@ -107,65 +103,325 @@ func CreateCommitMsg() {
os.Exit(1)
}
- var commitMsg string
+ attempt := 1
+ commitMsg, err := generateMessage(commitLLM, config, changes, apiKey, withAttempt(nil, attempt))
+ if err != nil {
+ spinnerGenerating.Fail("Failed to generate commit message")
+ displayProviderError(commitLLM, err)
+ os.Exit(1)
+ }
- switch commitLLM {
+ spinnerGenerating.Success("Commit message generated successfully!")
- case types.ProviderGemini:
- commitMsg, err = gemini.GenerateCommitMessage(config, changes, apiKey)
+ currentMessage := strings.TrimSpace(commitMsg)
+ currentStyleLabel := stylePresets[0].Label
+ var currentStyleOpts *types.GenerationOptions
+ accepted := false
+ finalMessage := ""
+
+interactionLoop:
+ for {
+ pterm.Println()
+ display.ShowCommitMessage(currentMessage)
+
+ action, err := promptActionSelection()
+ if err != nil {
+ pterm.Error.Printf("Failed to read selection: %v\n", err)
+ return
+ }
- case types.ProviderOpenAI:
- commitMsg, err = chatgpt.GenerateCommitMessage(config, changes, apiKey)
+ switch action {
+ case actionAcceptOption:
+ finalMessage = strings.TrimSpace(currentMessage)
+ if finalMessage == "" {
+ pterm.Warning.Println("Commit message is empty; please edit or regenerate before accepting.")
+ continue
+ }
+ if err := clipboard.WriteAll(finalMessage); err != nil {
+ pterm.Warning.Printf("Could not copy to clipboard: %v\n", err)
+ } else {
+ pterm.Success.Println("Commit message copied to clipboard!")
+ }
+ accepted = true
+ break interactionLoop
+ case actionRegenerateOption:
+ opts, styleLabel, err := promptStyleSelection(currentStyleLabel, currentStyleOpts)
+ if errors.Is(err, errSelectionCancelled) {
+ continue
+ }
+ if err != nil {
+ pterm.Error.Printf("Failed to select style: %v\n", err)
+ continue
+ }
+ if styleLabel != "" {
+ currentStyleLabel = styleLabel
+ }
+ currentStyleOpts = opts
+ nextAttempt := attempt + 1
+ generationOpts := withAttempt(currentStyleOpts, nextAttempt)
+ spinner, err := pterm.DefaultSpinner.
+ WithSequence("ā ", "ā ", "ā ¹", "ā ø", "ā ¼", "ā “", "ā ¦", "ā §", "ā ", "ā ").
+ Start(fmt.Sprintf("Regenerating commit message (%s)...", currentStyleLabel))
+ if err != nil {
+ pterm.Error.Printf("Failed to start spinner: %v\n", err)
+ continue
+ }
+ updatedMessage, genErr := generateMessage(commitLLM, config, changes, apiKey, generationOpts)
+ if genErr != nil {
+ spinner.Fail("Regeneration failed")
+ displayProviderError(commitLLM, genErr)
+ continue
+ }
+ spinner.Success("Commit message regenerated!")
+ attempt = nextAttempt
+ currentMessage = strings.TrimSpace(updatedMessage)
+ case actionEditOption:
+ edited, editErr := editCommitMessage(currentMessage)
+ if editErr != nil {
+ pterm.Error.Printf("Failed to edit commit message: %v\n", editErr)
+ continue
+ }
+ if strings.TrimSpace(edited) == "" {
+ pterm.Warning.Println("Edited commit message is empty; keeping previous message.")
+ continue
+ }
+ currentMessage = strings.TrimSpace(edited)
+ case actionExitOption:
+ pterm.Info.Println("Exiting without copying commit message.")
+ return
+ default:
+ pterm.Warning.Printf("Unknown selection: %s\n", action)
+ }
+ }
+
+ if !accepted {
+ return
+ }
+
+ pterm.Println()
+ display.ShowChangesPreview(fileStats)
+}
+
+type styleOption struct {
+ Label string
+ Instruction string
+}
+
+const (
+ actionAcceptOption = "Accept and copy commit message"
+ actionRegenerateOption = "Regenerate with different tone/style"
+ actionEditOption = "Edit message in editor"
+ actionExitOption = "Discard and exit"
+ customStyleOption = "Custom instructions (enter your own)"
+ styleBackOption = "Back to actions"
+)
+
+var (
+ actionOptions = []string{actionAcceptOption, actionRegenerateOption, actionEditOption, actionExitOption}
+ stylePresets = []styleOption{
+ {Label: "Concise conventional (default)", Instruction: ""},
+ {Label: "Detailed summary (adds bullet list)", Instruction: "Produce a conventional commit subject line followed by a blank line and bullet points summarizing the key changes."},
+ {Label: "Casual tone", Instruction: "Write the commit message in a friendly, conversational tone while still clearly explaining the changes."},
+ {Label: "Bug fix emphasis", Instruction: "Highlight the bug being fixed, reference the root cause when possible, and describe the remedy in the body."},
+ }
+ errSelectionCancelled = errors.New("selection cancelled")
+)
+func generateMessage(provider types.LLMProvider, config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) {
+ switch provider {
+ case types.ProviderGemini:
+ return gemini.GenerateCommitMessage(config, changes, apiKey, opts)
+ case types.ProviderOpenAI:
+ return chatgpt.GenerateCommitMessage(config, changes, apiKey, opts)
case types.ProviderClaude:
- commitMsg, err = claude.GenerateCommitMessage(config, changes, apiKey)
+ return claude.GenerateCommitMessage(config, changes, apiKey, opts)
case types.ProviderGroq:
- commitMsg, err = groq.GenerateCommitMessage(config, changes, apiKey)
+ return groq.GenerateCommitMessage(config, changes, apiKey, opts)
case types.ProviderOllama:
- model := "llama3:latest"
-
- commitMsg, err = ollama.GenerateCommitMessage(config, changes, apiKey, model)
+ url := apiKey
+ if strings.TrimSpace(url) == "" {
+ url = os.Getenv("OLLAMA_URL")
+ if url == "" {
+ url = "http://localhost:11434/api/generate"
+ }
+ }
+ model := os.Getenv("OLLAMA_MODEL")
+ if model == "" {
+ model = "llama3.1"
+ }
+ return ollama.GenerateCommitMessage(config, changes, url, model, opts)
default:
- commitMsg, err = grok.GenerateCommitMessage(config, changes, apiKey)
+ return grok.GenerateCommitMessage(config, changes, apiKey, opts)
}
+}
+func promptActionSelection() (string, error) {
+ return pterm.DefaultInteractiveSelect.
+ WithOptions(actionOptions).
+ WithDefaultOption(actionAcceptOption).
+ Show()
+}
+
+func promptStyleSelection(currentLabel string, currentOpts *types.GenerationOptions) (*types.GenerationOptions, string, error) {
+ options := make([]string, 0, len(stylePresets)+3)
+ foundCurrent := false
+ for _, preset := range stylePresets {
+ options = append(options, preset.Label)
+ if preset.Label == currentLabel {
+ foundCurrent = true
+ }
+ }
+ if currentOpts != nil && currentLabel != "" && !foundCurrent {
+ options = append(options, currentLabel)
+ }
+ options = append(options, customStyleOption, styleBackOption)
+
+ selector := pterm.DefaultInteractiveSelect.WithOptions(options)
+ if currentLabel != "" {
+ selector = selector.WithDefaultOption(currentLabel)
+ }
+
+ choice, err := selector.Show()
if err != nil {
- spinnerGenerating.Fail("Failed to generate commit message")
- switch commitLLM {
- case types.ProviderGemini:
- pterm.Error.Printf("Gemini API error. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n")
- case types.ProviderOpenAI:
- pterm.Error.Printf("OpenAI API error. Check your OPENAI_API_KEY environment variable or run: commit llm setup\n")
- case types.ProviderClaude:
- pterm.Error.Printf("Claude API error. Check your CLAUDE_API_KEY environment variable or run: commit llm setup\n")
- case types.ProviderGroq:
- pterm.Error.Printf("Groq API error. Check your GROQ_API_KEY environment variable or run: commit llm setup\n")
- case types.ProviderGrok:
- pterm.Error.Printf("Grok API error. Check your GROK_API_KEY environment variable or run: commit llm setup\n")
- default:
- pterm.Error.Printf("LLM API error: %v\n", err)
+ return currentOpts, currentLabel, err
+ }
+
+ switch choice {
+ case styleBackOption:
+ return currentOpts, currentLabel, errSelectionCancelled
+ case customStyleOption:
+ text, err := pterm.DefaultInteractiveTextInput.
+ WithDefaultText("Describe the tone or style you're looking for").
+ Show()
+ if err != nil {
+ return currentOpts, currentLabel, err
+ }
+ text = strings.TrimSpace(text)
+ if text == "" {
+ return currentOpts, currentLabel, errSelectionCancelled
+ }
+ return &types.GenerationOptions{StyleInstruction: text}, formatCustomStyleLabel(text), nil
+ default:
+ for _, preset := range stylePresets {
+ if choice == preset.Label {
+ if strings.TrimSpace(preset.Instruction) == "" {
+ return nil, preset.Label, nil
+ }
+ return &types.GenerationOptions{StyleInstruction: preset.Instruction}, preset.Label, nil
+ }
+ }
+ if currentOpts != nil && choice == currentLabel {
+ clone := *currentOpts
+ return &clone, currentLabel, nil
}
- os.Exit(1)
}
- spinnerGenerating.Success("Commit message generated successfully!")
+ if currentOpts != nil && currentLabel != "" {
+ clone := *currentOpts
+ return &clone, currentLabel, nil
+ }
+ return nil, currentLabel, nil
+}
- pterm.Println()
+func editCommitMessage(initial string) (string, error) {
+ command, args, err := resolveEditorCommand()
+ if err != nil {
+ return "", err
+ }
- // Display the commit message in a styled panel
- display.ShowCommitMessage(commitMsg)
+ tmpFile, err := os.CreateTemp("", "commit-msg-*.txt")
+ if err != nil {
+ return "", err
+ }
+ defer os.Remove(tmpFile.Name())
+
+ if _, err := tmpFile.WriteString(strings.TrimSpace(initial) + "\n"); err != nil {
+ tmpFile.Close()
+ return "", err
+ }
+
+ if err := tmpFile.Close(); err != nil {
+ return "", err
+ }
- // Copy to clipboard
- err = clipboard.WriteAll(commitMsg)
+ cmdArgs := append(args, tmpFile.Name())
+ cmd := exec.Command(command, cmdArgs...)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Run(); err != nil {
+ return "", fmt.Errorf("editor exited with error: %w", err)
+ }
+
+ content, err := os.ReadFile(tmpFile.Name())
if err != nil {
- pterm.Warning.Printf("Could not copy to clipboard: %v\n", err)
- } else {
- pterm.Success.Println("Commit message copied to clipboard!")
+ return "", err
}
- pterm.Println()
+ return strings.TrimSpace(string(content)), nil
+}
- // Display changes preview
- display.ShowChangesPreview(fileStats)
+func resolveEditorCommand() (string, []string, error) {
+ candidates := []string{
+ os.Getenv("GIT_EDITOR"),
+ os.Getenv("VISUAL"),
+ os.Getenv("EDITOR"),
+ }
+
+ for _, candidate := range candidates {
+ candidate = strings.TrimSpace(candidate)
+ if candidate == "" {
+ continue
+ }
+ parts, err := shlex.Split(candidate)
+ if err != nil {
+ return "", nil, fmt.Errorf("failed to parse editor command %q: %w", candidate, err)
+ }
+ if len(parts) == 0 {
+ continue
+ }
+ return parts[0], parts[1:], nil
+ }
+
+ if runtime.GOOS == "windows" {
+ return "notepad", nil, nil
+ }
+ return "nano", nil, nil
+}
+
+func formatCustomStyleLabel(instruction string) string {
+ trimmed := strings.TrimSpace(instruction)
+ runes := []rune(trimmed)
+ if len(runes) > 40 {
+ return fmt.Sprintf("Custom: %sā¦", string(runes[:37]))
+ }
+ return fmt.Sprintf("Custom: %s", trimmed)
+}
+
+func withAttempt(styleOpts *types.GenerationOptions, attempt int) *types.GenerationOptions {
+ if styleOpts == nil {
+ return &types.GenerationOptions{Attempt: attempt}
+ }
+ clone := *styleOpts
+ clone.Attempt = attempt
+ return &clone
+}
+
+func displayProviderError(provider types.LLMProvider, err error) {
+ switch provider {
+ case types.ProviderGemini:
+ pterm.Error.Printf("Gemini API error: %v. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n", err)
+ case types.ProviderOpenAI:
+ pterm.Error.Printf("OpenAI API error: %v. Check your OPENAI_API_KEY environment variable or run: commit llm setup\n", err)
+ case types.ProviderClaude:
+ pterm.Error.Printf("Claude API error: %v. Check your CLAUDE_API_KEY environment variable or run: commit llm setup\n", err)
+ case types.ProviderGroq:
+ pterm.Error.Printf("Groq API error: %v. Check your GROQ_API_KEY environment variable or run: commit llm setup\n", err)
+ case types.ProviderGrok:
+ pterm.Error.Printf("Grok API error: %v. Check your GROK_API_KEY environment variable or run: commit llm setup\n", err)
+ default:
+ pterm.Error.Printf("LLM API error: %v\n", err)
+ }
}
diff --git a/cmd/cli/llmSetup.go b/cmd/cli/llmSetup.go
index f551c46..4ef57e3 100644
--- a/cmd/cli/llmSetup.go
+++ b/cmd/cli/llmSetup.go
@@ -9,6 +9,8 @@ import (
"github.com/manifoldco/promptui"
)
+// SetupLLM walks the user through selecting an LLM provider and storing the
+// corresponding API key or endpoint configuration.
func SetupLLM() error {
providers := types.GetSupportedProviderStrings()
@@ -67,6 +69,8 @@ func SetupLLM() error {
return nil
}
+// UpdateLLM lets the user switch defaults, rotate API keys, or delete stored
+// LLM provider configurations.
func UpdateLLM() error {
SavedModels, err := store.ListSavedModels()
diff --git a/cmd/cli/store/store.go b/cmd/cli/store/store.go
index b18c1c2..efa2b98 100644
--- a/cmd/cli/store/store.go
+++ b/cmd/cli/store/store.go
@@ -1,3 +1,4 @@
+// Package store persists user-selected LLM providers and credentials.
package store
import (
@@ -11,16 +12,19 @@ import (
"github.com/dfanso/commit-msg/pkg/types"
)
+// LLMProvider represents a single stored LLM provider and its credential.
type LLMProvider struct {
LLM types.LLMProvider `json:"model"`
APIKey string `json:"api_key"`
}
+// Config describes the on-disk structure for all saved LLM providers.
type Config struct {
Default types.LLMProvider `json:"default"`
LLMProviders []LLMProvider `json:"models"`
}
+// Save persists or updates an LLM provider entry, marking it as the default.
func Save(LLMConfig LLMProvider) error {
cfg := Config{
@@ -139,6 +143,7 @@ func getConfigPath() (string, error) {
}
+// DefaultLLMKey returns the currently selected default LLM provider, if any.
func DefaultLLMKey() (*LLMProvider, error) {
var cfg Config
@@ -179,6 +184,7 @@ func DefaultLLMKey() (*LLMProvider, error) {
return nil, errors.New("not found default model in config")
}
+// ListSavedModels loads all persisted LLM provider configurations.
func ListSavedModels() (*Config, error) {
var cfg Config
@@ -211,6 +217,7 @@ func ListSavedModels() (*Config, error) {
}
+// ChangeDefault updates the default LLM provider selection in the config.
func ChangeDefault(Model types.LLMProvider) error {
var cfg Config
@@ -237,6 +244,17 @@ func ChangeDefault(Model types.LLMProvider) error {
}
}
+ found := false
+ for _, p := range cfg.LLMProviders {
+ if p.LLM == Model {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return fmt.Errorf("cannot set default to %s: no saved entry", Model.String())
+ }
+
cfg.Default = Model
data, err = json.MarshalIndent(cfg, "", " ")
@@ -247,6 +265,7 @@ func ChangeDefault(Model types.LLMProvider) error {
return os.WriteFile(configPath, data, 0600)
}
+// DeleteModel removes the specified provider from the saved configuration.
func DeleteModel(Model types.LLMProvider) error {
var cfg Config
@@ -300,6 +319,7 @@ func DeleteModel(Model types.LLMProvider) error {
}
}
+// UpdateAPIKey rotates the credential for an existing provider entry.
func UpdateAPIKey(Model types.LLMProvider, APIKey string) error {
var cfg Config
@@ -326,12 +346,18 @@ func UpdateAPIKey(Model types.LLMProvider, APIKey string) error {
}
}
+ updated := false
for i, p := range cfg.LLMProviders {
if p.LLM == Model {
cfg.LLMProviders[i].APIKey = APIKey
+ updated = true
}
}
+ if !updated {
+ return fmt.Errorf("no saved entry for %s to update", Model.String())
+ }
+
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
diff --git a/go.mod b/go.mod
index 99570d6..5860905 100644
--- a/go.mod
+++ b/go.mod
@@ -6,6 +6,7 @@ toolchain go1.24.7
require (
github.com/atotto/clipboard v0.1.4
 	github.com/google/generative-ai-go v0.19.0
+	github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/manifoldco/promptui v0.9.0
github.com/openai/openai-go/v3 v3.0.1
diff --git a/go.sum b/go.sum
index fe4be9b..7507149 100644
--- a/go.sum
+++ b/go.sum
@@ -58,6 +58,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
diff --git a/internal/chatgpt/chatgpt.go b/internal/chatgpt/chatgpt.go
index ba1cf2e..714c046 100644
--- a/internal/chatgpt/chatgpt.go
+++ b/internal/chatgpt/chatgpt.go
@@ -10,11 +10,13 @@ import (
"github.com/dfanso/commit-msg/pkg/types"
)
-func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) {
+// GenerateCommitMessage calls OpenAI's chat completions API to turn the provided
+// repository changes into a polished git commit message.
+func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) {
client := openai.NewClient(option.WithAPIKey(apiKey))
- prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes)
+ prompt := types.BuildCommitPrompt(changes, opts)
resp, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
diff --git a/internal/claude/claude.go b/internal/claude/claude.go
index 1ed0dc6..5a0a02b 100644
--- a/internal/claude/claude.go
+++ b/internal/claude/claude.go
@@ -10,17 +10,14 @@ import (
"github.com/dfanso/commit-msg/pkg/types"
)
+// ClaudeRequest describes the payload sent to Anthropic's Claude messages API.
type ClaudeRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- MaxTokens int `json:"max_tokens"`
-}
-
-type Message struct {
- Role string `json:"role"`
- Content string `json:"content"`
+ Model string `json:"model"`
+ Messages []types.Message `json:"messages"`
+ MaxTokens int `json:"max_tokens"`
}
+// ClaudeResponse captures the subset of fields used from Anthropic responses.
type ClaudeResponse struct {
ID string `json:"id"`
Type string `json:"type"`
@@ -30,14 +27,15 @@ type ClaudeResponse struct {
} `json:"content"`
}
-func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) {
+// GenerateCommitMessage produces a commit summary using Anthropic's Claude API.
+func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) {
- prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes)
+ prompt := types.BuildCommitPrompt(changes, opts)
reqBody := ClaudeRequest{
Model: "claude-3-5-sonnet-20241022",
MaxTokens: 200,
- Messages: []Message{
+ Messages: []types.Message{
{
Role: "user",
Content: prompt,
diff --git a/internal/gemini/gemini.go b/internal/gemini/gemini.go
index b231868..0606d49 100644
--- a/internal/gemini/gemini.go
+++ b/internal/gemini/gemini.go
@@ -10,9 +10,11 @@ import (
"github.com/dfanso/commit-msg/pkg/types"
)
-func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) {
+// GenerateCommitMessage asks Google Gemini to author a commit message for the
+// supplied repository changes and optional style instructions.
+func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) {
// Prepare request to Gemini API
- prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes)
+ prompt := types.BuildCommitPrompt(changes, opts)
// Create context and client
ctx := context.Background()
diff --git a/internal/grok/grok.go b/internal/grok/grok.go
index a5acc25..65ebd4e 100644
--- a/internal/grok/grok.go
+++ b/internal/grok/grok.go
@@ -5,16 +5,43 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
+ "sync"
"time"
"github.com/dfanso/commit-msg/pkg/types"
)
-func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) {
+var (
+ grokClientOnce sync.Once
+ grokClient *http.Client
+)
+
+func getHTTPClient() *http.Client {
+ grokClientOnce.Do(func() {
+ transport := &http.Transport{
+ TLSHandshakeTimeout: 10 * time.Second,
+ MaxIdleConns: 10,
+ IdleConnTimeout: 30 * time.Second,
+ DisableCompression: true,
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: false,
+ },
+ }
+ grokClient = &http.Client{
+ Timeout: 30 * time.Second,
+ Transport: transport,
+ }
+ })
+ return grokClient
+}
+
+// GenerateCommitMessage calls X.AI's Grok API to create a commit message from
+// the provided Git diff and generation options.
+func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) {
// Prepare request to X.AI (Grok) API
- prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes)
+ prompt := types.BuildCommitPrompt(changes, opts)
request := types.GrokRequest{
@@ -44,22 +71,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
- // Configure HTTP client with improved TLS settings
- transport := &http.Transport{
- TLSHandshakeTimeout: 10 * time.Second,
- MaxIdleConns: 10,
- IdleConnTimeout: 30 * time.Second,
- DisableCompression: true,
- // Add TLS config to handle server name mismatch
- TLSClientConfig: &tls.Config{
- InsecureSkipVerify: false, // Keep this false for security
- },
- }
-
- client := &http.Client{
- Timeout: 30 * time.Second,
- Transport: transport,
- }
+ client := getHTTPClient()
resp, err := client.Do(req)
if err != nil {
return "", err
@@ -68,7 +80,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string)
// Check response status
if resp.StatusCode != http.StatusOK {
- bodyBytes, _ := ioutil.ReadAll(resp.Body)
+ bodyBytes, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
}
diff --git a/internal/groq/groq.go b/internal/groq/groq.go
index cee46bd..b7a1be4 100644
--- a/internal/groq/groq.go
+++ b/internal/groq/groq.go
@@ -43,12 +43,12 @@ var (
)
// GenerateCommitMessage calls Groq's OpenAI-compatible chat completions API.
-func GenerateCommitMessage(_ *types.Config, changes string, apiKey string) (string, error) {
+func GenerateCommitMessage(_ *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) {
if changes == "" {
return "", fmt.Errorf("no changes provided for commit message generation")
}
- prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes)
+ prompt := types.BuildCommitPrompt(changes, opts)
model := os.Getenv("GROQ_MODEL")
if model == "" {
diff --git a/internal/groq/groq_test.go b/internal/groq/groq_test.go
index 754e4de..c1e2557 100644
--- a/internal/groq/groq_test.go
+++ b/internal/groq/groq_test.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
"github.com/dfanso/commit-msg/pkg/types"
@@ -73,7 +74,7 @@ func TestGenerateCommitMessageSuccess(t *testing.T) {
t.Fatalf("failed to write response: %v", err)
}
}, func() {
- msg, err := GenerateCommitMessage(&types.Config{}, "diff", "test-key")
+ msg, err := GenerateCommitMessage(&types.Config{}, "diff", "test-key", nil)
if err != nil {
t.Fatalf("GenerateCommitMessage returned error: %v", err)
}
@@ -89,7 +90,7 @@ func TestGenerateCommitMessageNonOK(t *testing.T) {
withTestServer(t, func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"bad things"}`, http.StatusBadGateway)
}, func() {
- _, err := GenerateCommitMessage(&types.Config{}, "changes", "key")
+ _, err := GenerateCommitMessage(&types.Config{}, "changes", "key", nil)
if err == nil {
t.Fatal("expected error but got nil")
}
@@ -100,7 +101,45 @@ func TestGenerateCommitMessageEmptyChanges(t *testing.T) {
t.Setenv("GROQ_MODEL", "")
t.Setenv("GROQ_API_URL", "")
- if _, err := GenerateCommitMessage(&types.Config{}, "", "key"); err == nil {
+ if _, err := GenerateCommitMessage(&types.Config{}, "", "key", nil); err == nil {
t.Fatal("expected error for empty changes")
}
}
+
+func TestGenerateCommitMessageIncludesStyleInstruction(t *testing.T) {
+ recorded := ""
+
+ withTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ var payload capturedRequest
+ if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
+ t.Fatalf("failed to decode request: %v", err)
+ }
+ recorded = payload.Messages[1].Content
+
+ resp := chatResponse{
+ Choices: []chatChoice{
+ {Message: chatMessage{Role: "assistant", Content: "feat: add style support"}},
+ },
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(resp); err != nil {
+ t.Fatalf("failed to write response: %v", err)
+ }
+ }, func() {
+ opts := &types.GenerationOptions{StyleInstruction: "Use a casual tone.", Attempt: 2}
+ if _, err := GenerateCommitMessage(&types.Config{}, "diff", "key", opts); err != nil {
+ t.Fatalf("GenerateCommitMessage returned error: %v", err)
+ }
+ })
+
+ if !strings.Contains(recorded, "Additional instructions:") {
+ t.Fatalf("expected request payload to contain additional instructions, got: %q", recorded)
+ }
+ if !strings.Contains(recorded, "Use a casual tone.") {
+ t.Fatalf("expected request payload to contain custom instruction, got: %q", recorded)
+ }
+ if !strings.Contains(recorded, "Regeneration context:") {
+ t.Fatalf("expected request payload to contain regeneration context, got: %q", recorded)
+ }
+}
diff --git a/internal/ollama/ollama.go b/internal/ollama/ollama.go
index 2094485..cb4f26e 100644
--- a/internal/ollama/ollama.go
+++ b/internal/ollama/ollama.go
@@ -10,24 +10,28 @@ import (
"github.com/dfanso/commit-msg/pkg/types"
)
+// OllamaRequest captures the prompt payload sent to an Ollama HTTP endpoint.
type OllamaRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
}
+// OllamaResponse represents the non-streaming response from Ollama.
type OllamaResponse struct {
Response string `json:"response"`
Done bool `json:"done"`
}
-func GenerateCommitMessage(_ *types.Config, changes string, url string, model string) (string, error) {
+// GenerateCommitMessage uses a locally hosted Ollama model to draft a commit
+// message from repository changes and optional style guidance.
+func GenerateCommitMessage(_ *types.Config, changes string, url string, model string, opts *types.GenerationOptions) (string, error) {
// Use llama3:latest as the default model
if model == "" {
model = "llama3:latest"
}
// Preparing the prompt
- prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes)
+ prompt := types.BuildCommitPrompt(changes, opts)
// Generating the request body - add stream: false for non-streaming response
reqBody := map[string]interface{}{
diff --git a/pkg/types/options.go b/pkg/types/options.go
new file mode 100644
index 0000000..d7ed46d
--- /dev/null
+++ b/pkg/types/options.go
@@ -0,0 +1,10 @@
+package types
+
+// GenerationOptions controls how commit messages should be produced by LLM providers.
+type GenerationOptions struct {
+ // StyleInstruction contains optional tone/style guidance appended to the base prompt.
+ StyleInstruction string
+ // Attempt records the 1-indexed attempt number for this generation request.
+ // Attempt > 1 signals that the LLM should provide an alternative output.
+ Attempt int
+}
diff --git a/pkg/types/prompt.go b/pkg/types/prompt.go
index 5b9a2f5..cd55192 100644
--- a/pkg/types/prompt.go
+++ b/pkg/types/prompt.go
@@ -1,5 +1,12 @@
package types
+import (
+ "fmt"
+ "strings"
+)
+
+// CommitPrompt is the base instruction template sent to LLM providers before
+// appending repository changes and optional style guidance.
var CommitPrompt = `I need a concise git commit message based on the following changes from my Git repository.
Please generate a commit message that:
1. Starts with a verb in the present tense (e.g., "Add", "Fix", "Update", "Feat", "Refactor", etc.)
@@ -17,3 +24,28 @@ here is a sample commit msgs:
- variable. Also updates dependencies and README.'
Here are the changes:
`
+
+// BuildCommitPrompt constructs the prompt that will be sent to the LLM, applying
+// any optional tone/style instructions before appending the repository changes.
+func BuildCommitPrompt(changes string, opts *GenerationOptions) string {
+ var builder strings.Builder
+ builder.WriteString(CommitPrompt)
+
+ if opts != nil {
+ if opts.Attempt > 1 {
+ builder.WriteString("\n\nRegeneration context:\n")
+ builder.WriteString(fmt.Sprintf("- This is attempt #%d.\n", opts.Attempt))
+ builder.WriteString("- Provide a commit message that is meaningfully different from earlier attempts.\n")
+ }
+
+ if strings.TrimSpace(opts.StyleInstruction) != "" {
+ builder.WriteString("\n\nAdditional instructions:\n")
+ builder.WriteString(strings.TrimSpace(opts.StyleInstruction))
+ }
+ }
+
+ builder.WriteString("\n\n")
+ builder.WriteString(changes)
+
+ return builder.String()
+}
diff --git a/pkg/types/types.go b/pkg/types/types.go
index 2e6f848..8bf7233 100644
--- a/pkg/types/types.go
+++ b/pkg/types/types.go
@@ -1,5 +1,7 @@
package types
+// LLMProvider identifies the large language model backend used to author
+// commit messages.
type LLMProvider string
const (
@@ -11,10 +13,12 @@ const (
ProviderOllama LLMProvider = "Ollama"
)
+// String returns the provider identifier as a plain string.
func (p LLMProvider) String() string {
return string(p)
}
+// IsValid reports whether the provider is part of the supported set.
func (p LLMProvider) IsValid() bool {
switch p {
case ProviderOpenAI, ProviderClaude, ProviderGemini, ProviderGrok, ProviderGroq, ProviderOllama:
@@ -24,6 +28,7 @@ func (p LLMProvider) IsValid() bool {
}
}
+// GetSupportedProviders returns all available provider enums.
func GetSupportedProviders() []LLMProvider {
return []LLMProvider{
ProviderOpenAI,
@@ -35,6 +40,7 @@ func GetSupportedProviders() []LLMProvider {
}
}
+// GetSupportedProviderStrings returns the human-friendly names for providers.
func GetSupportedProviderStrings() []string {
providers := GetSupportedProviders()
strings := make([]string, len(providers))
@@ -44,24 +50,25 @@ func GetSupportedProviderStrings() []string {
return strings
}
+// ParseLLMProvider converts a string into an LLMProvider enum when supported.
func ParseLLMProvider(s string) (LLMProvider, bool) {
provider := LLMProvider(s)
return provider, provider.IsValid()
}
-// Configuration structure
+// Config stores CLI-level configuration including named repositories.
type Config struct {
GrokAPI string `json:"grok_api"`
Repos map[string]RepoConfig `json:"repos"`
}
-// Repository configuration
+// RepoConfig tracks metadata for a configured Git repository.
type RepoConfig struct {
Path string `json:"path"`
LastRun string `json:"last_run"`
}
-// Grok/X.AI API request structure
+// GrokRequest represents a chat completion request sent to X.AI's API.
type GrokRequest struct {
Messages []Message `json:"messages"`
Model string `json:"model"`
@@ -69,12 +76,13 @@ type GrokRequest struct {
Temperature float64 `json:"temperature"`
}
+// Message captures the role/content pairs exchanged with Grok.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
-// Grok/X.AI API response structure
+// GrokResponse contains the relevant fields parsed from X.AI responses.
type GrokResponse struct {
Message Message `json:"message,omitempty"`
Choices []Choice `json:"choices,omitempty"`
@@ -85,12 +93,14 @@ type GrokResponse struct {
Usage UsageInfo `json:"usage,omitempty"`
}
+// Choice details a single response option returned by Grok.
type Choice struct {
Message Message `json:"message"`
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
}
+// UsageInfo reports token usage statistics from Grok responses.
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
diff --git a/pkg/types/types_test.go b/pkg/types/types_test.go
index 710fff2..0231633 100644
--- a/pkg/types/types_test.go
+++ b/pkg/types/types_test.go
@@ -48,3 +48,58 @@ func TestCommitPromptContent(t *testing.T) {
}
}
}
+
+func TestBuildCommitPromptDefault(t *testing.T) {
+ t.Parallel()
+
+ changes := "diff --git a/main.go b/main.go"
+ prompt := BuildCommitPrompt(changes, nil)
+
+ if !strings.HasSuffix(prompt, changes) {
+ t.Fatalf("expected prompt to end with changes, got %q", prompt)
+ }
+
+ if strings.Contains(prompt, "Additional instructions:") {
+ t.Fatalf("expected no additional instructions block in default prompt")
+ }
+}
+
+func TestBuildCommitPromptWithInstructions(t *testing.T) {
+ t.Parallel()
+
+ changes := "diff --git a/main.go b/main.go"
+ options := &GenerationOptions{StyleInstruction: "Use a playful tone."}
+ prompt := BuildCommitPrompt(changes, options)
+
+ if !strings.Contains(prompt, "Additional instructions:") {
+ t.Fatalf("expected prompt to contain additional instructions block")
+ }
+
+ if !strings.Contains(prompt, options.StyleInstruction) {
+ t.Fatalf("expected prompt to include style instruction %q", options.StyleInstruction)
+ }
+
+ if !strings.HasSuffix(prompt, changes) {
+ t.Fatalf("expected prompt to end with changes, got %q", prompt)
+ }
+}
+
+func TestBuildCommitPromptWithAttempt(t *testing.T) {
+ t.Parallel()
+
+ changes := "diff --git a/main.go b/main.go"
+ options := &GenerationOptions{Attempt: 3}
+ prompt := BuildCommitPrompt(changes, options)
+
+ if !strings.Contains(prompt, "Regeneration context:") {
+ t.Fatalf("expected regeneration context section, got %q", prompt)
+ }
+
+ if !strings.Contains(prompt, "attempt #3") {
+ t.Fatalf("expected attempt number to be mentioned, got %q", prompt)
+ }
+
+ if !strings.HasSuffix(prompt, changes) {
+ t.Fatalf("expected prompt to end with changes, got %q", prompt)
+ }
+}