diff --git a/cmd/cli/createMsg.go b/cmd/cli/createMsg.go index efd8d96..a303039 100644 --- a/cmd/cli/createMsg.go +++ b/cmd/cli/createMsg.go @@ -8,6 +8,7 @@ import ( "os/exec" "runtime" "strings" + "time" "github.com/atotto/clipboard" "github.com/dfanso/commit-msg/cmd/cli/store" @@ -18,8 +19,13 @@ import ( "github.com/dfanso/commit-msg/pkg/types" "github.com/google/shlex" "github.com/pterm/pterm" + "golang.org/x/time/rate" ) +// Burst once every 5 times per second +// Make the limiter a global variable to better control the rate when it is used. +var apiRateLimiter = rate.NewLimiter(rate.Every(time.Second/5), 5) + // CreateCommitMsg launches the interactive flow for reviewing, regenerating, // editing, and accepting AI-generated commit messages in the current repo. // If dryRun is true, it displays the prompt without making an API call. @@ -294,6 +300,9 @@ func resolveOllamaConfig(apiKey string) (url, model string) { } func generateMessage(ctx context.Context, provider llm.Provider, changes string, opts *types.GenerationOptions) (string, error) { + if err := apiRateLimiter.Wait(ctx); err != nil { + return "", err + } return provider.Generate(ctx, changes, opts) } diff --git a/cmd/cli/createMsg_test.go b/cmd/cli/createMsg_test.go new file mode 100644 index 0000000..263761e --- /dev/null +++ b/cmd/cli/createMsg_test.go @@ -0,0 +1,48 @@ +package cmd + +import ( + "context" + "sync" + "testing" + + "github.com/dfanso/commit-msg/pkg/types" +) + +// FakeProvider implements the llm.Provider interface to simulate API responses +type FakeProvider struct{} + +func (f FakeProvider) Name() types.LLMProvider { return "fake" } + +func (f FakeProvider) Generate(ctx context.Context, changes string, opts *types.GenerationOptions) (string, error) { + return "mock commit message", nil +} + +func TestGenerateMessageRateLimiter(t *testing.T) { + ctx := context.Background() + var waitGroup sync.WaitGroup + successCount := 0 + var mu sync.Mutex + + // Test sending a number of messages in a short period to check the rate limiter + numCalls := 100 + waitGroup.Add(numCalls) + for i := 0; i < numCalls; i++ { + go func() { + defer waitGroup.Done() + _, err := generateMessage(ctx, FakeProvider{}, "", nil) + if err != nil { + t.Logf("rate limiter error: %v", err) + return + } + mu.Lock() + successCount++ + mu.Unlock() + }() + } + waitGroup.Wait() + + t.Logf("Successful calls: %d out of %d", successCount, numCalls) + if successCount != numCalls { + t.Errorf("expected %d successful calls but got %d", numCalls, successCount) + } +}