diff --git a/package.json b/package.json index 6717b67..fb2c832 100644 --- a/package.json +++ b/package.json @@ -49,7 +49,7 @@ "predicteCommit.provider": { "type": "string", "default": "mistral", - "description": "AI provider to use (e.g. 'mistral', 'local'). Matches the registered provider ID.", + "description": "AI provider to use (e.g. 'mistral', 'ollama', 'vllm', 'lmstudio'). Matches the registered provider ID.", "order": 1 }, "predicteCommit.models": { @@ -77,19 +77,19 @@ "predicteCommit.useLocal": { "type": "boolean", "default": false, - "description": "Use local Ollama-compatible endpoint instead of Mistral cloud", + "description": "Use a local provider (Ollama, vLLM, LM Studio) instead of Mistral cloud", "order": 10 }, "predicteCommit.localBaseUrl": { "type": "string", - "default": "http://localhost:11434/v1", - "description": "Base URL for local provider (POST to /chat/completions)", + "default": "", + "description": "Base URL for local provider. Leave empty to use the provider's default (Ollama: :11434, vLLM: :8000, LM Studio: :1234).", "order": 11 }, "predicteCommit.localModel": { "type": "string", - "default": "mistral", - "description": "Model name for local provider", + "default": "", + "description": "Model name for local provider. Leave empty to assume default.", "order": 12 }, "predicteCommit.debugLogging": { diff --git a/src/core/config.ts b/src/core/config.ts index d66e854..42d7607 100644 --- a/src/core/config.ts +++ b/src/core/config.ts @@ -10,6 +10,8 @@ export type PredicteCommitConfig = { debugLogging: boolean; }; +export const DEFAULT_LOCAL_URL = ''; + export function getConfig(): PredicteCommitConfig { const cfg = vscode.workspace.getConfiguration('predicteCommit'); const provider = cfg.get('provider', 'mistral'); @@ -19,8 +21,8 @@ export function getConfig(): PredicteCommitConfig { models: cfg.get('models', []), ignoredFiles: cfg.get('ignoredFiles', ['*-lock.json', '*.svg', 'dist/**']), useLocal, - localBaseUrl: cfg.get('localBaseUrl', 'http://localhost:11434/v1'), - localModel: cfg.get('localModel', 'mistral'), + localBaseUrl: cfg.get('localBaseUrl', DEFAULT_LOCAL_URL), + localModel: cfg.get('localModel', ''), debugLogging: cfg.get('debugLogging', false), }; } @@ -35,7 +37,7 @@ export const TRUNCATION_MARKER = '...TRUNCATED...'; export function getEffectiveProviderId(cfg: PredicteCommitConfig): string { // Backwards compatibility: existing users may rely on useLocal. if (cfg.useLocal) { - return 'local'; + return 'ollama'; } return cfg.provider; } diff --git a/src/providers/local/index.ts b/src/providers/local/index.ts index 0c26f5b..bfe3426 100644 --- a/src/providers/local/index.ts +++ b/src/providers/local/index.ts @@ -1,11 +1,14 @@ import { postChatCompletion } from '../../ai/http'; import type { GenerateRequest, GenerateResult, ProviderClient } from '../../ai/types'; import { registerProvider } from '../../ai/registry'; +import { DEFAULT_LOCAL_URL } from '../../core/config'; export class LocalProvider implements ProviderClient { - readonly id = 'local'; - - constructor(private readonly baseUrl: string, private readonly model: string) {} + constructor( + readonly id: string, + private readonly baseUrl: string, + private readonly model: string + ) {} async generate(req: GenerateRequest): Promise { const url = `${this.baseUrl.replace(/\/$/, '')}/chat/completions`; @@ -20,11 +23,35 @@ export class LocalProvider implements ProviderClient { } } -registerProvider({ - id: 'local', - label: 'Local (Ollama)', - create: async (_context, config) => { +const createLocalProvider = (id: string, defaultUrl: string) => { + return async (_context: any, config: any) => { + let baseUrl = config.localBaseUrl; + // If the user hasn't changed the URL (it's empty), + // then use that provider's default. + // Otherwise (user changed it), use what's in config. + if (baseUrl === DEFAULT_LOCAL_URL) { // DEFAULT_LOCAL_URL is '' + baseUrl = defaultUrl; + } + const model = config.models.length > 0 ? config.models[0] : config.localModel; - return new LocalProvider(config.localBaseUrl, model); - }, + return new LocalProvider(id, baseUrl, model); + }; +}; + +registerProvider({ + id: 'ollama', + label: 'Ollama', + create: createLocalProvider('ollama', 'http://localhost:11434/v1'), +}); + +registerProvider({ + id: 'vllm', + label: 'Local (vLLM)', + create: createLocalProvider('vllm', 'http://localhost:8000/v1'), +}); + +registerProvider({ + id: 'lmstudio', + label: 'Local (LM Studio)', + create: createLocalProvider('lmstudio', 'http://localhost:1234/v1'), }); diff --git a/src/test/config.test.ts b/src/test/config.test.ts index 7880b24..8f0e58e 100644 --- a/src/test/config.test.ts +++ b/src/test/config.test.ts @@ -1,9 +1,9 @@ import * as assert from 'assert'; -import { getEffectiveProviderId, getConfig, PredicteCommitConfig } from '../core/config'; +import { getEffectiveProviderId, getConfig, PredicteCommitConfig, DEFAULT_LOCAL_URL } from '../core/config'; suite('Config Test Suite', () => { test('getEffectiveProviderId', () => { - // Test 1: useLocal is true -> should return 'local' (backward compatibility) + // Test 1: useLocal is true -> should return 'ollama' (renamed from local) const cfgLocal: PredicteCommitConfig = { provider: 'mistral', useLocal: true, @@ -13,7 +13,7 @@ suite('Config Test Suite', () => { localModel: '', debugLogging: false }; - assert.strictEqual(getEffectiveProviderId(cfgLocal), 'local'); + assert.strictEqual(getEffectiveProviderId(cfgLocal), 'ollama'); // Test 2: useLocal is false -> should return provider const cfgMistral: PredicteCommitConfig = { @@ -23,13 +23,13 @@ suite('Config Test Suite', () => { }; assert.strictEqual(getEffectiveProviderId(cfgMistral), 'mistral'); - // Test 3: provider is local (explicit) + // Test 3: provider is ollama (explicit) const cfgLocalExplicit: PredicteCommitConfig = { ...cfgLocal, useLocal: false, - provider: 'local' + provider: 'ollama' }; - assert.strictEqual(getEffectiveProviderId(cfgLocalExplicit), 'local'); + assert.strictEqual(getEffectiveProviderId(cfgLocalExplicit), 'ollama'); }); test('getConfig defaults', () => { @@ -41,5 +41,7 @@ suite('Config Test Suite', () => { assert.strictEqual(cfg.provider, 'mistral'); assert.strictEqual(cfg.useLocal, false); assert.deepStrictEqual(cfg.ignoredFiles, ['*-lock.json', '*.svg', 'dist/**']); + assert.strictEqual(cfg.localBaseUrl, DEFAULT_LOCAL_URL); + assert.strictEqual(cfg.localModel, ''); }); });