Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cmd/cli/createMsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os/exec"
"runtime"
"strings"
"time"

"github.com/atotto/clipboard"
"github.com/dfanso/commit-msg/cmd/cli/store"
Expand All @@ -18,8 +19,13 @@ import (
"github.com/dfanso/commit-msg/pkg/types"
"github.com/google/shlex"
"github.com/pterm/pterm"
"golang.org/x/time/rate"
)

// Burst once every 5 times per second
// Make the limiter a global variable to better control the rate when it is used.
var apiRateLimiter = rate.NewLimiter(rate.Every(time.Second/5), 5)

// CreateCommitMsg launches the interactive flow for reviewing, regenerating,
// editing, and accepting AI-generated commit messages in the current repo.
// If dryRun is true, it displays the prompt without making an API call.
Expand Down Expand Up @@ -294,6 +300,9 @@ func resolveOllamaConfig(apiKey string) (url, model string) {
}

func generateMessage(ctx context.Context, provider llm.Provider, changes string, opts *types.GenerationOptions) (string, error) {
if err := apiRateLimiter.Wait(ctx); err != nil {
return "", err
}
return provider.Generate(ctx, changes, opts)
}

Expand Down
48 changes: 48 additions & 0 deletions cmd/cli/createMsg_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package cmd

import (
"context"
"sync"
"testing"

"github.com/dfanso/commit-msg/pkg/types"
)

// FakeProvider implements the llm.Provider interface to simulate API responses
type FakeProvider struct{}

func (f FakeProvider) Name() types.LLMProvider { return "fake" }

func (f FakeProvider) Generate(ctx context.Context, changes string, opts *types.GenerationOptions) (string, error) {
return "mock commit message", nil
}

func TestGenerateMessageRateLimiter(t *testing.T) {
ctx := context.Background()
var waitGroup sync.WaitGroup
successCount := 0
var mu sync.Mutex

// Test sending a number of messages in a short period to check the rate limiter
numCalls := 100
waitGroup.Add(numCalls)
for i := 0; i < numCalls; i++ {
go func() {
defer waitGroup.Done()
_, err := generateMessage(ctx, FakeProvider{}, "", nil)
if err != nil {
t.Logf("rate limiter error: %v", err)
return
}
mu.Lock()
successCount++
mu.Unlock()
}()
}
waitGroup.Wait()

t.Logf("Successful calls: %d out of %d", successCount, numCalls)
if successCount != numCalls {
t.Errorf("expected %d successful calls but got %d", numCalls, successCount)
}
}
Loading