Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"predicteCommit.provider": {
"type": "string",
"default": "mistral",
"description": "AI provider to use (e.g. 'mistral', 'local'). Matches the registered provider ID.",
"description": "AI provider to use (e.g. 'mistral', 'ollama', 'vllm', 'lmstudio'). Matches the registered provider ID.",
"order": 1
},
"predicteCommit.models": {
Expand Down Expand Up @@ -77,19 +77,19 @@
"predicteCommit.useLocal": {
"type": "boolean",
"default": false,
"description": "Use local Ollama-compatible endpoint instead of Mistral cloud",
"description": "Use a local provider (Ollama, vLLM, LM Studio) instead of Mistral cloud",
"order": 10
},
"predicteCommit.localBaseUrl": {
"type": "string",
"default": "http://localhost:11434/v1",
"description": "Base URL for local provider (POST to /chat/completions)",
"default": "",
"description": "Base URL for local provider. Leave empty to use the provider's default (Ollama: :11434, vLLM: :8000, LM Studio: :1234).",
"order": 11
},
"predicteCommit.localModel": {
"type": "string",
"default": "mistral",
"description": "Model name for local provider",
"default": "",
"description": "Model name for local provider. Leave empty to assume default.",
"order": 12
},
"predicteCommit.debugLogging": {
Expand Down
8 changes: 5 additions & 3 deletions src/core/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ export type PredicteCommitConfig = {
debugLogging: boolean;
};

export const DEFAULT_LOCAL_URL = '';

export function getConfig(): PredicteCommitConfig {
const cfg = vscode.workspace.getConfiguration('predicteCommit');
const provider = cfg.get<string>('provider', 'mistral');
Expand All @@ -19,8 +21,8 @@ export function getConfig(): PredicteCommitConfig {
models: cfg.get<string[]>('models', []),
ignoredFiles: cfg.get<string[]>('ignoredFiles', ['*-lock.json', '*.svg', 'dist/**']),
useLocal,
localBaseUrl: cfg.get<string>('localBaseUrl', 'http://localhost:11434/v1'),
localModel: cfg.get<string>('localModel', 'mistral'),
localBaseUrl: cfg.get<string>('localBaseUrl', DEFAULT_LOCAL_URL),
localModel: cfg.get<string>('localModel', ''),
debugLogging: cfg.get<boolean>('debugLogging', false),
};
}
Expand All @@ -35,7 +37,7 @@ export const TRUNCATION_MARKER = '...TRUNCATED...';
export function getEffectiveProviderId(cfg: PredicteCommitConfig): string {
// Backwards compatibility: existing users may rely on useLocal.
if (cfg.useLocal) {
return 'local';
return 'ollama';
}
return cfg.provider;
}
45 changes: 36 additions & 9 deletions src/providers/local/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { postChatCompletion } from '../../ai/http';
import type { GenerateRequest, GenerateResult, ProviderClient } from '../../ai/types';
import { registerProvider } from '../../ai/registry';
import { DEFAULT_LOCAL_URL } from '../../core/config';

export class LocalProvider implements ProviderClient {
readonly id = 'local';

constructor(private readonly baseUrl: string, private readonly model: string) {}
constructor(
readonly id: string,
private readonly baseUrl: string,
private readonly model: string
) {}

async generate(req: GenerateRequest): Promise<GenerateResult> {
const url = `${this.baseUrl.replace(/\/$/, '')}/chat/completions`;
Expand All @@ -20,11 +23,35 @@ export class LocalProvider implements ProviderClient {
}
}

registerProvider({
id: 'local',
label: 'Local (Ollama)',
create: async (_context, config) => {
const createLocalProvider = (id: string, defaultUrl: string) => {
return async (_context: any, config: any) => {
let baseUrl = config.localBaseUrl;
// If the user hasn't changed the URL (it's empty),
// then use that provider's default.
// Otherwise (user changed it), use what's in config.
if (baseUrl === DEFAULT_LOCAL_URL) { // DEFAULT_LOCAL_URL is ''
baseUrl = defaultUrl;
}

const model = config.models.length > 0 ? config.models[0] : config.localModel;
return new LocalProvider(config.localBaseUrl, model);
},
return new LocalProvider(id, baseUrl, model);
};
};

registerProvider({
id: 'ollama',
label: 'Ollama',
create: createLocalProvider('ollama', 'http://localhost:11434/v1'),
});

registerProvider({
id: 'vllm',
label: 'Local (vLLM)',
create: createLocalProvider('vllm', 'http://localhost:8000/v1'),
});

registerProvider({
id: 'lmstudio',
label: 'Local (LM Studio)',
create: createLocalProvider('lmstudio', 'http://localhost:1234/v1'),
});
14 changes: 8 additions & 6 deletions src/test/config.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import * as assert from 'assert';
import { getEffectiveProviderId, getConfig, PredicteCommitConfig } from '../core/config';
import { getEffectiveProviderId, getConfig, PredicteCommitConfig, DEFAULT_LOCAL_URL } from '../core/config';

suite('Config Test Suite', () => {
test('getEffectiveProviderId', () => {
// Test 1: useLocal is true -> should return 'local' (backward compatibility)
// Test 1: useLocal is true -> should return 'ollama' (renamed from local)
const cfgLocal: PredicteCommitConfig = {
provider: 'mistral',
useLocal: true,
Expand All @@ -13,7 +13,7 @@ suite('Config Test Suite', () => {
localModel: '',
debugLogging: false
};
assert.strictEqual(getEffectiveProviderId(cfgLocal), 'local');
assert.strictEqual(getEffectiveProviderId(cfgLocal), 'ollama');

// Test 2: useLocal is false -> should return provider
const cfgMistral: PredicteCommitConfig = {
Expand All @@ -23,13 +23,13 @@ suite('Config Test Suite', () => {
};
assert.strictEqual(getEffectiveProviderId(cfgMistral), 'mistral');

// Test 3: provider is local (explicit)
// Test 3: provider is ollama (explicit)
const cfgLocalExplicit: PredicteCommitConfig = {
...cfgLocal,
useLocal: false,
provider: 'local'
provider: 'ollama'
};
assert.strictEqual(getEffectiveProviderId(cfgLocalExplicit), 'local');
assert.strictEqual(getEffectiveProviderId(cfgLocalExplicit), 'ollama');
});

test('getConfig defaults', () => {
Expand All @@ -41,5 +41,7 @@ suite('Config Test Suite', () => {
assert.strictEqual(cfg.provider, 'mistral');
assert.strictEqual(cfg.useLocal, false);
assert.deepStrictEqual(cfg.ignoredFiles, ['*-lock.json', '*.svg', 'dist/**']);
assert.strictEqual(cfg.localBaseUrl, DEFAULT_LOCAL_URL);
assert.strictEqual(cfg.localModel, '');
});
});