diff --git a/package-lock.json b/package-lock.json index f192ac2..88f43da 100644 --- a/package-lock.json +++ b/package-lock.json @@ -47,7 +47,7 @@ "eslint": "^8.57.0", "tsup": "^8.1.0", "tsx": "^4.11.0", - "typescript": "^5.4.0", + "typescript": "^5.9.3", "vitest": "^1.6.0" }, "engines": { diff --git a/package.json b/package.json index 7f02901..ddc4800 100644 --- a/package.json +++ b/package.json @@ -65,7 +65,7 @@ "eslint": "^8.57.0", "tsup": "^8.1.0", "tsx": "^4.11.0", - "typescript": "^5.4.0", + "typescript": "^5.9.3", "vitest": "^1.6.0" }, "engines": { @@ -97,4 +97,4 @@ "bugs": { "url": "https://github.com/d1maash/sortora/issues" } -} \ No newline at end of file +} diff --git a/src/ai/providers/anthropic.ts b/src/ai/providers/anthropic.ts new file mode 100644 index 0000000..0b5fd4d --- /dev/null +++ b/src/ai/providers/anthropic.ts @@ -0,0 +1,137 @@ +/** + * Anthropic (Claude) Provider + * Uses Anthropic API for file classification + */ + +import type { + AIProvider, + ClassificationRequest, + ClassificationResult, + ProviderConfig, +} from './types.js'; +import { buildClassificationPrompt, parseClassificationResponse } from './types.js'; + +export interface AnthropicConfig extends ProviderConfig { + type: 'anthropic'; + apiKey?: string; + baseUrl?: string; + model?: string; +} + +interface AnthropicMessage { + role: 'user' | 'assistant'; + content: string; +} + +interface AnthropicResponse { + id: string; + type: 'message'; + role: 'assistant'; + content: Array<{ + type: 'text'; + text: string; + }>; + stop_reason: string; + usage: { + input_tokens: number; + output_tokens: number; + }; +} + +export class AnthropicProvider implements AIProvider { + readonly name = 'Claude (Anthropic)'; + readonly type = 'anthropic' as const; + + private apiKey: string; + private baseUrl: string; + private model: string; + private initialized = false; + + constructor(config: AnthropicConfig) { + this.apiKey = config.apiKey || ''; + this.baseUrl = config.baseUrl || 'https://api.anthropic.com/v1'; + this.model = config.model || 'claude-3-haiku-20240307'; + } + + isReady(): boolean { + return this.initialized && !!this.apiKey; + } + + async init(): Promise { + if (!this.apiKey) { + throw new Error('Anthropic API key is required'); + } + + // Validate by checking if key format is correct + if (!this.apiKey.startsWith('sk-ant-')) { + throw new Error('Invalid Anthropic API key format'); + } + + this.initialized = true; + } + + async classifyFile(request: ClassificationRequest): Promise { + if (!this.isReady()) { + throw new Error('Anthropic provider is not initialized'); + } + + const prompt = buildClassificationPrompt(request); + + const messages: AnthropicMessage[] = [ + { + role: 'user', + content: prompt, + }, + ]; + + const response = await fetch(`${this.baseUrl}/messages`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-api-key': this.apiKey, + 'anthropic-version': '2023-06-01', + }, + body: JSON.stringify({ + model: this.model, + max_tokens: 100, + system: 'You are a file classification assistant. Respond only with valid JSON, no markdown formatting.', + messages, + }), + }); + + if (!response.ok) { + const error = await response.text(); + throw new Error(`Anthropic API error: ${error}`); + } + + const data = await response.json() as AnthropicResponse; + const content = data.content[0]?.text || ''; + + return parseClassificationResponse(content); + } + + async classifyBatch(requests: ClassificationRequest[]): Promise { + // Process in parallel with rate limiting + const results: ClassificationResult[] = []; + const batchSize = 5; + + for (let i = 0; i < requests.length; i += batchSize) { + const batch = requests.slice(i, i + batchSize); + const batchResults = await Promise.all( + batch.map(req => this.classifyFile(req)) + ); + results.push(...batchResults); + + // Small delay between batches to avoid rate limits + if (i + batchSize < requests.length) { + await new Promise(resolve => setTimeout(resolve, 100)); + } + } + + return results; + } + + async dispose(): Promise { + this.initialized = false; + } +} diff --git a/src/ai/providers/gemini.ts b/src/ai/providers/gemini.ts new file mode 100644 index 0000000..1c5250e --- /dev/null +++ b/src/ai/providers/gemini.ts @@ -0,0 +1,152 @@ +/** + * Google Gemini Provider + * Uses Google Generative AI API for file classification + */ + +import type { + AIProvider, + ClassificationRequest, + ClassificationResult, + ProviderConfig, +} from './types.js'; +import { buildClassificationPrompt, parseClassificationResponse } from './types.js'; + +export interface GeminiConfig extends ProviderConfig { + type: 'gemini'; + apiKey?: string; + model?: string; +} + +interface GeminiContent { + parts: Array<{ text: string }>; + role: 'user' | 'model'; +} + +interface GeminiResponse { + candidates: Array<{ + content: { + parts: Array<{ text: string }>; + role: string; + }; + finishReason: string; + }>; + usageMetadata?: { + promptTokenCount: number; + candidatesTokenCount: number; + totalTokenCount: number; + }; +} + +export class GeminiProvider implements AIProvider { + readonly name = 'Gemini (Google)'; + readonly type = 'gemini' as const; + + private apiKey: string; + private model: string; + private initialized = false; + + constructor(config: GeminiConfig) { + this.apiKey = config.apiKey || ''; + this.model = config.model || 'gemini-1.5-flash'; + } + + isReady(): boolean { + return this.initialized && !!this.apiKey; + } + + async init(): Promise { + if (!this.apiKey) { + throw new Error('Gemini API key is required'); + } + + // Validate API key by making a simple request + try { + const response = await fetch( + `https://generativelanguage.googleapis.com/v1beta/models?key=${this.apiKey}` + ); + + if (!response.ok) { + const error = await response.text(); + throw new Error(`Gemini API validation failed: ${error}`); + } + + this.initialized = true; + } catch (error) { + if (error instanceof Error && error.message.includes('fetch')) { + // Network error - assume key is valid + this.initialized = true; + } else { + throw error; + } + } + } + + async classifyFile(request: ClassificationRequest): Promise { + if (!this.isReady()) { + throw new Error('Gemini provider is not initialized'); + } + + const prompt = buildClassificationPrompt(request); + + const contents: GeminiContent[] = [ + { + role: 'user', + parts: [{ text: prompt }], + }, + ]; + + const url = `https://generativelanguage.googleapis.com/v1beta/models/${this.model}:generateContent?key=${this.apiKey}`; + + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + contents, + generationConfig: { + temperature: 0.1, + maxOutputTokens: 100, + }, + systemInstruction: { + parts: [{ text: 'You are a file classification assistant. Respond only with valid JSON, no markdown formatting.' }], + }, + }), + }); + + if (!response.ok) { + const error = await response.text(); + throw new Error(`Gemini API error: ${error}`); + } + + const data = await response.json() as GeminiResponse; + const content = data.candidates[0]?.content?.parts[0]?.text || ''; + + return parseClassificationResponse(content); + } + + async classifyBatch(requests: ClassificationRequest[]): Promise { + // Process in parallel with rate limiting + const results: ClassificationResult[] = []; + const batchSize = 5; + + for (let i = 0; i < requests.length; i += batchSize) { + const batch = requests.slice(i, i + batchSize); + const batchResults = await Promise.all( + batch.map(req => this.classifyFile(req)) + ); + results.push(...batchResults); + + // Small delay between batches to avoid rate limits + if (i + batchSize < requests.length) { + await new Promise(resolve => setTimeout(resolve, 100)); + } + } + + return results; + } + + async dispose(): Promise { + this.initialized = false; + } +} diff --git a/src/ai/providers/index.ts b/src/ai/providers/index.ts new file mode 100644 index 0000000..b0419a1 --- /dev/null +++ b/src/ai/providers/index.ts @@ -0,0 +1,193 @@ +/** + * AI Provider Factory and Manager + * Creates and manages AI providers based on configuration + */ + +export type { AIProvider, ProviderConfig, ProviderType, ClassificationRequest, ClassificationResult } from './types.js'; +export { buildClassificationPrompt, parseClassificationResponse } from './types.js'; + +export { OpenAIProvider, type OpenAIConfig } from './openai.js'; +export { AnthropicProvider, type AnthropicConfig } from './anthropic.js'; +export { GeminiProvider, type GeminiConfig } from './gemini.js'; +export { OllamaProvider, type OllamaConfig } from './ollama.js'; +export { LocalProvider, type LocalConfig } from './local.js'; + +import type { AIProvider, ProviderType } from './types.js'; +import { OpenAIProvider, type OpenAIConfig } from './openai.js'; +import { AnthropicProvider, type AnthropicConfig } from './anthropic.js'; +import { GeminiProvider, type GeminiConfig } from './gemini.js'; +import { OllamaProvider, type OllamaConfig } from './ollama.js'; +import { LocalProvider, type LocalConfig } from './local.js'; + +export interface ProviderManagerConfig { + provider: ProviderType; + openai?: Omit; + anthropic?: Omit; + gemini?: Omit; + ollama?: Omit; + local?: Omit; +} + +/** + * Create an AI provider based on configuration + */ +export function createProvider(config: ProviderManagerConfig): AIProvider { + switch (config.provider) { + case 'openai': + if (!config.openai?.apiKey) { + throw new Error('OpenAI API key is required. Set OPENAI_API_KEY or configure in settings.'); + } + return new OpenAIProvider({ + type: 'openai', + apiKey: config.openai.apiKey, + baseUrl: config.openai.baseUrl, + model: config.openai.model, + }); + + case 'anthropic': + if (!config.anthropic?.apiKey) { + throw new Error('Anthropic API key is required. Set ANTHROPIC_API_KEY or configure in settings.'); + } + return new AnthropicProvider({ + type: 'anthropic', + apiKey: config.anthropic.apiKey, + baseUrl: config.anthropic.baseUrl, + model: config.anthropic.model, + }); + + case 'gemini': + if (!config.gemini?.apiKey) { + throw new Error('Gemini API key is required. Set GEMINI_API_KEY or configure in settings.'); + } + return new GeminiProvider({ + type: 'gemini', + apiKey: config.gemini.apiKey, + model: config.gemini.model, + }); + + case 'ollama': + return new OllamaProvider({ + type: 'ollama', + baseUrl: config.ollama?.baseUrl, + model: config.ollama?.model, + }); + + case 'local': + if (!config.local?.modelsDir) { + throw new Error('Models directory is required for local provider.'); + } + return new LocalProvider({ + type: 'local', + modelsDir: config.local.modelsDir, + }); + + default: + throw new Error(`Unknown provider type: ${config.provider}`); + } +} + +/** + * Get provider configuration from environment variables + */ +export function getProviderConfigFromEnv(modelsDir: string): ProviderManagerConfig { + const provider = (process.env.SORTORA_AI_PROVIDER as ProviderType) || 'local'; + + return { + provider, + openai: { + apiKey: process.env.OPENAI_API_KEY || '', + baseUrl: process.env.OPENAI_BASE_URL, + model: process.env.OPENAI_MODEL, + }, + anthropic: { + apiKey: process.env.ANTHROPIC_API_KEY || '', + baseUrl: process.env.ANTHROPIC_BASE_URL, + model: process.env.ANTHROPIC_MODEL, + }, + gemini: { + apiKey: process.env.GEMINI_API_KEY || process.env.GOOGLE_API_KEY || '', + model: process.env.GEMINI_MODEL, + }, + ollama: { + baseUrl: process.env.OLLAMA_HOST || process.env.OLLAMA_BASE_URL, + model: process.env.OLLAMA_MODEL, + }, + local: { + modelsDir, + }, + }; +} + +/** + * List all available provider types with descriptions + */ +export function listProviders(): Array<{ type: ProviderType; name: string; description: string }> { + return [ + { + type: 'local', + name: 'Local (MobileBERT)', + description: 'Offline classification using local MobileBERT model (~25 MB)', + }, + { + type: 'openai', + name: 'OpenAI', + description: 'Use OpenAI GPT models (requires API key)', + }, + { + type: 'anthropic', + name: 'Claude (Anthropic)', + description: 'Use Anthropic Claude models (requires API key)', + }, + { + type: 'gemini', + name: 'Gemini (Google)', + description: 'Use Google Gemini models (requires API key)', + }, + { + type: 'ollama', + name: 'Ollama (Local LLM)', + description: 'Use local LLM via Ollama server', + }, + ]; +} + +/** + * Provider Manager class for managing provider lifecycle + */ +export class ProviderManager { + private provider: AIProvider | null = null; + private config: ProviderManagerConfig; + + constructor(config: ProviderManagerConfig) { + this.config = config; + } + + getProviderType(): ProviderType { + return this.config.provider; + } + + async init(): Promise { + if (this.provider && this.provider.isReady()) { + return this.provider; + } + + this.provider = createProvider(this.config); + await this.provider.init(); + return this.provider; + } + + getProvider(): AIProvider | null { + return this.provider; + } + + isReady(): boolean { + return this.provider !== null && this.provider.isReady(); + } + + async dispose(): Promise { + if (this.provider?.dispose) { + await this.provider.dispose(); + } + this.provider = null; + } +} diff --git a/src/ai/providers/local.ts b/src/ai/providers/local.ts new file mode 100644 index 0000000..45eece5 --- /dev/null +++ b/src/ai/providers/local.ts @@ -0,0 +1,77 @@ +/** + * Local Provider + * Wraps the existing local classifier (MobileBERT via @xenova/transformers) + */ + +import type { + AIProvider, + ClassificationRequest, + ClassificationResult, + ProviderConfig, +} from './types.js'; +import { ClassifierService } from '../classifier.js'; + +export interface LocalConfig extends ProviderConfig { + type: 'local'; + modelsDir: string; +} + +export class LocalProvider implements AIProvider { + readonly name = 'Local (MobileBERT)'; + readonly type = 'local' as const; + + private modelsDir: string; + private classifier: ClassifierService | null = null; + private initialized = false; + + constructor(config: LocalConfig) { + this.modelsDir = config.modelsDir; + } + + isReady(): boolean { + return this.initialized && this.classifier !== null; + } + + async init(): Promise { + this.classifier = new ClassifierService(this.modelsDir); + await this.classifier.init(); + this.initialized = true; + } + + async classifyFile(request: ClassificationRequest): Promise { + if (!this.classifier) { + throw new Error('Local provider is not initialized'); + } + + const result = await this.classifier.classifyFile({ + filename: request.filename, + content: request.content, + metadata: request.metadata, + }); + + return { + category: result.category, + confidence: result.confidence, + }; + } + + async classifyBatch(requests: ClassificationRequest[]): Promise { + if (!this.classifier) { + throw new Error('Local provider is not initialized'); + } + + const results: ClassificationResult[] = []; + + for (const request of requests) { + const result = await this.classifyFile(request); + results.push(result); + } + + return results; + } + + async dispose(): Promise { + this.classifier = null; + this.initialized = false; + } +} diff --git a/src/ai/providers/ollama.ts b/src/ai/providers/ollama.ts new file mode 100644 index 0000000..bb3ffaf --- /dev/null +++ b/src/ai/providers/ollama.ts @@ -0,0 +1,142 @@ +/** + * Ollama Provider + * Uses local Ollama server for file classification + */ + +import type { + AIProvider, + ClassificationRequest, + ClassificationResult, + ProviderConfig, +} from './types.js'; +import { buildClassificationPrompt, parseClassificationResponse } from './types.js'; + +export interface OllamaConfig extends ProviderConfig { + type: 'ollama'; + baseUrl?: string; + model?: string; +} + +interface OllamaResponse { + model: string; + created_at: string; + response?: string; + message?: { + role: string; + content: string; + }; + done: boolean; + total_duration?: number; + load_duration?: number; + prompt_eval_count?: number; + eval_count?: number; +} + +export class OllamaProvider implements AIProvider { + readonly name = 'Ollama (Local)'; + readonly type = 'ollama' as const; + + private baseUrl: string; + private model: string; + private initialized = false; + + constructor(config: OllamaConfig) { + this.baseUrl = config.baseUrl || 'http://localhost:11434'; + this.model = config.model || 'llama3.2'; + } + + isReady(): boolean { + return this.initialized; + } + + async init(): Promise { + // Check if Ollama server is running + try { + const response = await fetch(`${this.baseUrl}/api/tags`); + + if (!response.ok) { + throw new Error('Ollama server is not responding'); + } + + const data = await response.json() as { models?: Array<{ name: string }> }; + const models = data.models || []; + + // Check if the specified model is available + const modelExists = models.some((m: { name: string }) => + m.name === this.model || m.name.startsWith(`${this.model}:`) + ); + + if (!modelExists && models.length > 0) { + console.warn(`Model "${this.model}" not found. Available models: ${models.map((m: { name: string }) => m.name).join(', ')}`); + } + + this.initialized = true; + } catch (error) { + if (error instanceof Error) { + if (error.message.includes('ECONNREFUSED') || error.message.includes('fetch')) { + throw new Error('Ollama server is not running. Start it with: ollama serve'); + } + } + throw error; + } + } + + async classifyFile(request: ClassificationRequest): Promise { + if (!this.isReady()) { + throw new Error('Ollama provider is not initialized'); + } + + const prompt = buildClassificationPrompt(request); + + const response = await fetch(`${this.baseUrl}/api/chat`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: this.model, + messages: [ + { + role: 'system', + content: 'You are a file classification assistant. Respond only with valid JSON, no markdown formatting.', + }, + { + role: 'user', + content: prompt, + }, + ], + stream: false, + options: { + temperature: 0.1, + num_predict: 100, + }, + }), + }); + + if (!response.ok) { + const error = await response.text(); + throw new Error(`Ollama API error: ${error}`); + } + + const data = await response.json() as OllamaResponse; + const content = data.message?.content || data.response || ''; + + return parseClassificationResponse(content); + } + + async classifyBatch(requests: ClassificationRequest[]): Promise { + // Process sequentially for Ollama (single model loaded at a time) + const results: ClassificationResult[] = []; + + for (const request of requests) { + const result = await this.classifyFile(request); + results.push(result); + } + + return results; + } + + async dispose(): Promise { + this.initialized = false; + } +} diff --git a/src/ai/providers/openai.ts b/src/ai/providers/openai.ts new file mode 100644 index 0000000..8253f49 --- /dev/null +++ b/src/ai/providers/openai.ts @@ -0,0 +1,156 @@ +/** + * OpenAI Provider + * Uses OpenAI API for file classification + */ + +import type { + AIProvider, + ClassificationRequest, + ClassificationResult, + ProviderConfig, +} from './types.js'; +import { buildClassificationPrompt, parseClassificationResponse } from './types.js'; + +export interface OpenAIConfig extends ProviderConfig { + type: 'openai'; + apiKey?: string; + baseUrl?: string; + model?: string; +} + +interface OpenAIMessage { + role: 'system' | 'user' | 'assistant'; + content: string; +} + +interface OpenAIResponse { + id: string; + choices: Array<{ + message: { + content: string; + }; + finish_reason: string; + }>; + usage?: { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + }; +} + +export class OpenAIProvider implements AIProvider { + readonly name = 'OpenAI'; + readonly type = 'openai' as const; + + private apiKey: string; + private baseUrl: string; + private model: string; + private initialized = false; + + constructor(config: OpenAIConfig) { + this.apiKey = config.apiKey || ''; + this.baseUrl = config.baseUrl || 'https://api.openai.com/v1'; + this.model = config.model || 'gpt-4o-mini'; + } + + isReady(): boolean { + return this.initialized && !!this.apiKey; + } + + async init(): Promise { + if (!this.apiKey) { + throw new Error('OpenAI API key is required'); + } + + // Validate API key by making a simple request + try { + const response = await fetch(`${this.baseUrl}/models`, { + headers: { + Authorization: `Bearer ${this.apiKey}`, + }, + }); + + if (!response.ok) { + const error = await response.text(); + throw new Error(`OpenAI API validation failed: ${error}`); + } + + this.initialized = true; + } catch (error) { + if (error instanceof Error && error.message.includes('fetch')) { + // Network error - assume key is valid, will fail on actual request + this.initialized = true; + } else { + throw error; + } + } + } + + async classifyFile(request: ClassificationRequest): Promise { + if (!this.isReady()) { + throw new Error('OpenAI provider is not initialized'); + } + + const prompt = buildClassificationPrompt(request); + + const messages: OpenAIMessage[] = [ + { + role: 'system', + content: 'You are a file classification assistant. Respond only with valid JSON.', + }, + { + role: 'user', + content: prompt, + }, + ]; + + const response = await fetch(`${this.baseUrl}/chat/completions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${this.apiKey}`, + }, + body: JSON.stringify({ + model: this.model, + messages, + temperature: 0.1, + max_tokens: 100, + }), + }); + + if (!response.ok) { + const error = await response.text(); + throw new Error(`OpenAI API error: ${error}`); + } + + const data = await response.json() as OpenAIResponse; + const content = data.choices[0]?.message?.content || ''; + + return parseClassificationResponse(content); + } + + async classifyBatch(requests: ClassificationRequest[]): Promise { + // Process in parallel with rate limiting + const results: ClassificationResult[] = []; + const batchSize = 5; + + for (let i = 0; i < requests.length; i += batchSize) { + const batch = requests.slice(i, i + batchSize); + const batchResults = await Promise.all( + batch.map(req => this.classifyFile(req)) + ); + results.push(...batchResults); + + // Small delay between batches to avoid rate limits + if (i + batchSize < requests.length) { + await new Promise(resolve => setTimeout(resolve, 100)); + } + } + + return results; + } + + async dispose(): Promise { + this.initialized = false; + } +} diff --git a/src/ai/providers/types.ts b/src/ai/providers/types.ts new file mode 100644 index 0000000..54463bb --- /dev/null +++ b/src/ai/providers/types.ts @@ -0,0 +1,142 @@ +/** + * AI Provider Types + * Defines interfaces for external AI providers (OpenAI, Claude, Gemini, etc.) + */ + +import { FILE_CATEGORIES } from '../classifier.js'; + +export type ProviderType = 'local' | 'openai' | 'anthropic' | 'gemini' | 'ollama'; + +export interface ProviderConfig { + type: ProviderType; + apiKey?: string; + baseUrl?: string; + model?: string; +} + +export interface ClassificationRequest { + filename: string; + content?: string; + metadata?: Record; +} + +export interface ClassificationResult { + category: string; + confidence: number; + rawResponse?: string; +} + +export interface AIProvider { + readonly name: string; + readonly type: ProviderType; + + /** + * Check if the provider is properly configured and ready to use + */ + isReady(): boolean; + + /** + * Initialize the provider (load models, validate API keys, etc.) + */ + init(): Promise; + + /** + * Classify a file based on its name, content, and metadata + */ + classifyFile(request: ClassificationRequest): Promise; + + /** + * Classify multiple files in batch (for providers that support it) + */ + classifyBatch?(requests: ClassificationRequest[]): Promise; + + /** + * Clean up resources + */ + dispose?(): Promise; +} + +/** + * Build the classification prompt for external AI providers + */ +export function buildClassificationPrompt(request: ClassificationRequest): string { + const categories = FILE_CATEGORIES.join(', '); + + let prompt = `You are a file classification assistant. Classify the following file into one of these categories: ${categories}. + +File information: +- Filename: ${request.filename}`; + + if (request.content) { + const contentPreview = request.content.slice(0, 500); + prompt += `\n- Content preview: ${contentPreview}`; + } + + if (request.metadata) { + const metaStr = Object.entries(request.metadata) + .filter(([_, v]) => v !== undefined && v !== null) + .slice(0, 5) + .map(([k, v]) => `${k}: ${v}`) + .join(', '); + if (metaStr) { + prompt += `\n- Metadata: ${metaStr}`; + } + } + + prompt += ` + +Respond with ONLY a JSON object in this exact format (no markdown, no explanation): +{"category": "", "confidence": <0.0-1.0>} + +The category must be one of: ${categories}`; + + return prompt; +} + +/** + * Parse the AI response to extract classification result + */ +export function parseClassificationResponse(response: string): ClassificationResult { + // Try to extract JSON from the response + const jsonMatch = response.match(/\{[^}]+\}/); + + if (jsonMatch) { + try { + const parsed = JSON.parse(jsonMatch[0]); + const category = parsed.category?.toLowerCase() || 'unknown or other'; + const confidence = typeof parsed.confidence === 'number' + ? Math.min(1, Math.max(0, parsed.confidence)) + : 0.7; + + // Validate category is in our list + const validCategory = FILE_CATEGORIES.find( + c => c.toLowerCase() === category.toLowerCase() + ) || 'unknown or other'; + + return { + category: validCategory, + confidence, + rawResponse: response, + }; + } catch { + // JSON parsing failed + } + } + + // Fallback: try to find a category in the response + for (const cat of FILE_CATEGORIES) { + if (response.toLowerCase().includes(cat.toLowerCase())) { + return { + category: cat, + confidence: 0.6, + rawResponse: response, + }; + } + } + + return { + category: 'unknown or other', + confidence: 0.3, + rawResponse: response, + }; +} diff --git a/src/cli.ts b/src/cli.ts index f9cee78..47aee06 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -7,7 +7,8 @@ import inquirer from 'inquirer'; import { resolve } from 'path'; import { existsSync } from 'fs'; -import { VERSION, loadConfig, saveConfig, getAppPaths, ensureDirectories, expandPath } from './config.js'; +import { VERSION, loadConfig, saveConfig, getAppPaths, ensureDirectories, expandPath, getAIProviderConfig, type AIProviderType } from './config.js'; +import { listProviders, type ProviderManagerConfig } from './ai/providers/index.js'; import { Scanner } from './core/scanner.js'; import { Analyzer } from './core/analyzer.js'; import { RuleEngine } from './core/rule-engine.js'; @@ -18,7 +19,6 @@ import { Database } from './storage/database.js'; import { ModelManager } from './ai/model-manager.js'; import { SmartRenamer } from './ai/smart-renamer.js'; import { renameFile } from './actions/rename.js'; -import { renderScanStats, renderFileTable } from './ui/table.js'; import { renderScanStats, renderFileTable, @@ -116,6 +116,7 @@ program .option('-d, --deep', 'Scan recursively') .option('--duplicates', 'Find duplicate files') .option('--ai', 'Use AI for smart classification') + .option('--provider ', 'AI provider to use (local, openai, anthropic, gemini, ollama)') .option('--json', 'Output as JSON') .action(async (targetPath, options) => { const fullPath = resolve(expandPath(targetPath)); @@ -136,12 +137,29 @@ program // Enable AI if requested if (options.ai) { - const aiSpinner = ora('Loading AI models...').start(); + const config = loadConfig(); + const aiConfig = getAIProviderConfig(config); + + // Override provider if specified via CLI + const providerType = (options.provider as AIProviderType) || aiConfig.provider; + + const providerConfig: ProviderManagerConfig = { + provider: providerType, + openai: aiConfig.openai, + anthropic: aiConfig.anthropic, + gemini: aiConfig.gemini, + ollama: aiConfig.ollama, + local: { modelsDir: paths.modelsDir }, + }; + + const providerName = listProviders().find(p => p.type === providerType)?.name || providerType; + const aiSpinner = ora(`Loading AI provider (${providerName})...`).start(); + try { - await analyzer.enableAI(); - aiSpinner.succeed('AI models loaded'); + await analyzer.enableAIWithProvider(providerConfig); + aiSpinner.succeed(`AI provider loaded: ${analyzer.getActiveProviderName()}`); } catch (error) { - aiSpinner.fail('Failed to load AI models. Run "sortora setup" first.'); + aiSpinner.fail(`Failed to load AI provider. ${error instanceof Error ? error.message : ''}`); logger.error('AI error:', error); } } @@ -973,6 +991,190 @@ program } }); +// ═══════════════════════════════════════════════════════════════ +// AI command - manage AI providers +// ═══════════════════════════════════════════════════════════════ +program + .command('ai') + .description('Manage AI providers for classification') + .argument('[action]', 'list, set, test, info') + .argument('[provider]', 'Provider name for set action (openai, anthropic, gemini, ollama, local)') + .action(async (action, provider) => { + const config = loadConfig(); + const paths = getAppPaths(); + const aiConfig = getAIProviderConfig(config); + + if (!action || action === 'list') { + console.log(chalk.bold('\n Available AI Providers:\n')); + + const providers = listProviders(); + for (const p of providers) { + const isActive = p.type === aiConfig.provider; + const marker = isActive ? chalk.green('●') : chalk.dim('○'); + const name = isActive ? chalk.green(p.name) : p.name; + console.log(` ${marker} ${name}`); + console.log(chalk.dim(` ${p.description}`)); + + // Show configuration status + if (p.type === 'openai' && aiConfig.openai?.apiKey) { + console.log(chalk.dim(` API Key: ****${aiConfig.openai.apiKey.slice(-4)}`)); + } + if (p.type === 'anthropic' && aiConfig.anthropic?.apiKey) { + console.log(chalk.dim(` API Key: ****${aiConfig.anthropic.apiKey.slice(-4)}`)); + } + if (p.type === 'gemini' && aiConfig.gemini?.apiKey) { + console.log(chalk.dim(` API Key: ****${aiConfig.gemini.apiKey.slice(-4)}`)); + } + if (p.type === 'ollama') { + console.log(chalk.dim(` URL: ${aiConfig.ollama?.baseUrl || 'http://localhost:11434'}`)); + } + console.log(); + } + + console.log(chalk.dim(' Set provider: sortora ai set ')); + console.log(chalk.dim(' Test provider: sortora ai test\n')); + return; + } + + if (action === 'info') { + console.log(chalk.bold('\n Current AI Configuration:\n')); + console.log(chalk.cyan(` Provider: ${aiConfig.provider}`)); + + if (aiConfig.provider === 'openai') { + console.log(chalk.dim(` Model: ${aiConfig.openai?.model || 'gpt-4o-mini'}`)); + console.log(chalk.dim(` API Key: ${aiConfig.openai?.apiKey ? '****' + aiConfig.openai.apiKey.slice(-4) : 'not set'}`)); + } else if (aiConfig.provider === 'anthropic') { + console.log(chalk.dim(` Model: ${aiConfig.anthropic?.model || 'claude-3-haiku-20240307'}`)); + console.log(chalk.dim(` API Key: ${aiConfig.anthropic?.apiKey ? '****' + aiConfig.anthropic.apiKey.slice(-4) : 'not set'}`)); + } else if (aiConfig.provider === 'gemini') { + console.log(chalk.dim(` Model: ${aiConfig.gemini?.model || 'gemini-1.5-flash'}`)); + console.log(chalk.dim(` API Key: ${aiConfig.gemini?.apiKey ? '****' + aiConfig.gemini.apiKey.slice(-4) : 'not set'}`)); + } else if (aiConfig.provider === 'ollama') { + console.log(chalk.dim(` Model: ${aiConfig.ollama?.model || 'llama3.2'}`)); + console.log(chalk.dim(` URL: ${aiConfig.ollama?.baseUrl || 'http://localhost:11434'}`)); + } else if (aiConfig.provider === 'local') { + console.log(chalk.dim(` Models directory: ${paths.modelsDir}`)); + } + + console.log(chalk.dim('\n Environment variables:')); + console.log(chalk.dim(` SORTORA_AI_PROVIDER: ${process.env.SORTORA_AI_PROVIDER || '(not set)'}`)); + console.log(chalk.dim(` OPENAI_API_KEY: ${process.env.OPENAI_API_KEY ? 'set' : '(not set)'}`)); + console.log(chalk.dim(` ANTHROPIC_API_KEY: ${process.env.ANTHROPIC_API_KEY ? 'set' : '(not set)'}`)); + console.log(chalk.dim(` GEMINI_API_KEY: ${process.env.GEMINI_API_KEY ? 'set' : '(not set)'}`)); + console.log(chalk.dim(` OLLAMA_HOST: ${process.env.OLLAMA_HOST || '(not set)'}\n`)); + return; + } + + if (action === 'set') { + if (!provider) { + console.log(chalk.red('\n Please specify a provider: openai, anthropic, gemini, ollama, local\n')); + return; + } + + const validProviders = ['local', 'openai', 'anthropic', 'gemini', 'ollama']; + if (!validProviders.includes(provider)) { + console.log(chalk.red(`\n Unknown provider: ${provider}`)); + console.log(chalk.dim(` Available: ${validProviders.join(', ')}\n`)); + return; + } + + // Check if API key is required + if (provider === 'openai' && !aiConfig.openai?.apiKey) { + const { apiKey } = await inquirer.prompt([{ + type: 'password', + name: 'apiKey', + message: 'Enter your OpenAI API key:', + mask: '*', + }]); + config.ai.openai.apiKey = apiKey; + } + + if (provider === 'anthropic' && !aiConfig.anthropic?.apiKey) { + const { apiKey } = await inquirer.prompt([{ + type: 'password', + name: 'apiKey', + message: 'Enter your Anthropic API key:', + mask: '*', + }]); + config.ai.anthropic.apiKey = apiKey; + } + + if (provider === 'gemini' && !aiConfig.gemini?.apiKey) { + const { apiKey } = await inquirer.prompt([{ + type: 'password', + name: 'apiKey', + message: 'Enter your Gemini API key:', + mask: '*', + }]); + config.ai.gemini.apiKey = apiKey; + } + + config.ai.provider = provider as AIProviderType; + saveConfig(config); + + console.log(chalk.green(`\n AI provider set to: ${provider}\n`)); + return; + } + + if (action === 'test') { + const spinner = ora('Testing AI provider...').start(); + + try { + const analyzer = new Analyzer(paths.modelsDir); + + const providerConfig: ProviderManagerConfig = { + provider: aiConfig.provider, + openai: aiConfig.openai, + anthropic: aiConfig.anthropic, + gemini: aiConfig.gemini, + ollama: aiConfig.ollama, + local: { modelsDir: paths.modelsDir }, + }; + + await analyzer.enableAIWithProvider(providerConfig); + spinner.succeed(`AI provider initialized: ${analyzer.getActiveProviderName()}`); + + // Test classification + const testSpinner = ora('Testing classification...').start(); + const testResult = await analyzer.classifyWithAI({ + path: '/test/example-document.pdf', + filename: 'quarterly-report-2024.pdf', + extension: 'pdf', + size: 1024, + created: new Date(), + modified: new Date(), + accessed: new Date(), + mimeType: 'application/pdf', + category: 'document', + textContent: 'Q4 2024 Financial Summary. Revenue increased by 15%.', + }); + + testSpinner.succeed(`Classification test passed`); + console.log(chalk.dim(` Category: ${testResult.category}`)); + console.log(chalk.dim(` Confidence: ${Math.round(testResult.confidence * 100)}%\n`)); + } catch (error) { + spinner.fail(`AI provider test failed: ${error instanceof Error ? error.message : 'Unknown error'}`); + console.log(chalk.dim('\n Tips:')); + if (aiConfig.provider === 'openai') { + console.log(chalk.dim(' - Make sure OPENAI_API_KEY is set or configure via "sortora ai set openai"')); + } else if (aiConfig.provider === 'anthropic') { + console.log(chalk.dim(' - Make sure ANTHROPIC_API_KEY is set or configure via "sortora ai set anthropic"')); + } else if (aiConfig.provider === 'gemini') { + console.log(chalk.dim(' - Make sure GEMINI_API_KEY is set or configure via "sortora ai set gemini"')); + } else if (aiConfig.provider === 'ollama') { + console.log(chalk.dim(' - Make sure Ollama is running: ollama serve')); + console.log(chalk.dim(' - Pull a model: ollama pull llama3.2')); + } else if (aiConfig.provider === 'local') { + console.log(chalk.dim(' - Run "sortora setup" to download local models')); + } + console.log(); + } + return; + } + + console.log(chalk.yellow('\n Unknown action. Use: list, set, test, info\n')); + }); + // Show animated banner when run without arguments async function main() { const args = process.argv.slice(2); diff --git a/src/config.ts b/src/config.ts index f8fe3e6..8d917bc 100644 --- a/src/config.ts +++ b/src/config.ts @@ -6,6 +6,29 @@ import { z } from 'zod'; export const VERSION = '1.1.1'; +// AI Provider configuration schema +const AIProviderSchema = z.object({ + provider: z.enum(['local', 'openai', 'anthropic', 'gemini', 'ollama']).default('local'), + openai: z.object({ + apiKey: z.string().optional(), + baseUrl: z.string().optional(), + model: z.string().default('gpt-4o-mini'), + }).default({}), + anthropic: z.object({ + apiKey: z.string().optional(), + baseUrl: z.string().optional(), + model: z.string().default('claude-3-haiku-20240307'), + }).default({}), + gemini: z.object({ + apiKey: z.string().optional(), + model: z.string().default('gemini-1.5-flash'), + }).default({}), + ollama: z.object({ + baseUrl: z.string().default('http://localhost:11434'), + model: z.string().default('llama3.2'), + }).default({}), +}).default({}); + const ConfigSchema = z.object({ version: z.number().default(1), settings: z.object({ @@ -20,6 +43,7 @@ const ConfigSchema = z.object({ 'desktop.ini', ]), }).default({}), + ai: AIProviderSchema, destinations: z.record(z.string()).default({ photos: '~/Pictures/Sorted', screenshots: '~/Pictures/Screenshots', @@ -143,3 +167,82 @@ export function getDestination(config: Config, key: string): string { } export const DEFAULT_CONFIG: Config = ConfigSchema.parse({}); + +export type AIProviderType = 'local' | 'openai' | 'anthropic' | 'gemini' | 'ollama'; + +/** + * Get the AI provider configuration, with environment variable overrides + */ +export function getAIProviderConfig(config: Config): Config['ai'] { + const aiConfig = { ...config.ai }; + + // Environment variable overrides + const envProvider = process.env.SORTORA_AI_PROVIDER as AIProviderType | undefined; + if (envProvider) { + aiConfig.provider = envProvider; + } + + // OpenAI + if (process.env.OPENAI_API_KEY) { + aiConfig.openai = { + ...aiConfig.openai, + apiKey: process.env.OPENAI_API_KEY, + }; + } + if (process.env.OPENAI_BASE_URL) { + aiConfig.openai = { + ...aiConfig.openai, + baseUrl: process.env.OPENAI_BASE_URL, + }; + } + if (process.env.OPENAI_MODEL) { + aiConfig.openai = { + ...aiConfig.openai, + model: process.env.OPENAI_MODEL, + }; + } + + // Anthropic + if (process.env.ANTHROPIC_API_KEY) { + aiConfig.anthropic = { + ...aiConfig.anthropic, + apiKey: process.env.ANTHROPIC_API_KEY, + }; + } + if (process.env.ANTHROPIC_MODEL) { + aiConfig.anthropic = { + ...aiConfig.anthropic, + model: process.env.ANTHROPIC_MODEL, + }; + } + + // Gemini + if (process.env.GEMINI_API_KEY || process.env.GOOGLE_API_KEY) { + aiConfig.gemini = { + ...aiConfig.gemini, + apiKey: process.env.GEMINI_API_KEY || process.env.GOOGLE_API_KEY, + }; + } + if (process.env.GEMINI_MODEL) { + aiConfig.gemini = { + ...aiConfig.gemini, + model: process.env.GEMINI_MODEL, + }; + } + + // Ollama + if (process.env.OLLAMA_HOST || process.env.OLLAMA_BASE_URL) { + aiConfig.ollama = { + ...aiConfig.ollama, + baseUrl: process.env.OLLAMA_HOST || process.env.OLLAMA_BASE_URL || aiConfig.ollama.baseUrl, + }; + } + if (process.env.OLLAMA_MODEL) { + aiConfig.ollama = { + ...aiConfig.ollama, + model: process.env.OLLAMA_MODEL, + }; + } + + return aiConfig; +} diff --git a/src/core/analyzer.ts b/src/core/analyzer.ts index 8117f9a..7bdb134 100644 --- a/src/core/analyzer.ts +++ b/src/core/analyzer.ts @@ -5,6 +5,12 @@ import { hashFileQuick } from '../utils/file-hash.js'; import { analyzeByType, type FileMetadata } from '../analyzers/index.js'; import { EmbeddingService } from '../ai/embeddings.js'; import { ClassifierService } from '../ai/classifier.js'; +import { + type AIProvider, + type ProviderType, + ProviderManager, + type ProviderManagerConfig, +} from '../ai/providers/index.js'; import type { ScanResult } from './scanner.js'; export interface FileAnalysis { @@ -25,27 +31,64 @@ export interface FileAnalysis { aiConfidence?: number; } +export interface AnalyzerOptions { + providerConfig?: ProviderManagerConfig; +} + export class Analyzer { private modelsDir: string; private aiEnabled = false; private embeddingService: EmbeddingService | null = null; private classifierService: ClassifierService | null = null; + private providerManager: ProviderManager | null = null; + private activeProvider: AIProvider | null = null; constructor(modelsDir: string) { this.modelsDir = modelsDir; } + /** + * Enable AI with the default local provider + */ async enableAI(): Promise { - if (this.aiEnabled) return; + return this.enableAIWithProvider({ + provider: 'local', + local: { modelsDir: this.modelsDir }, + }); + } + + /** + * Enable AI with a specific provider configuration + */ + async enableAIWithProvider(config: ProviderManagerConfig): Promise { + if (this.aiEnabled && this.providerManager?.getProviderType() === config.provider) { + return; + } + + // Dispose previous provider if exists + if (this.providerManager) { + await this.providerManager.dispose(); + } + + // Ensure local config has modelsDir + if (config.provider === 'local' && !config.local?.modelsDir) { + config.local = { modelsDir: this.modelsDir }; + } + + this.providerManager = new ProviderManager(config); + this.activeProvider = await this.providerManager.init(); - this.embeddingService = new EmbeddingService(this.modelsDir); - this.classifierService = new ClassifierService(this.modelsDir); + // Also initialize embedding service for similarity search + if (!this.embeddingService) { + this.embeddingService = new EmbeddingService(this.modelsDir); + await this.embeddingService.init(); + } - // Initialize both services - await Promise.all([ - this.embeddingService.init(), - this.classifierService.init(), - ]); + // Keep legacy classifier for backward compatibility + if (config.provider === 'local' && !this.classifierService) { + this.classifierService = new ClassifierService(this.modelsDir); + await this.classifierService.init(); + } this.aiEnabled = true; } @@ -54,6 +97,14 @@ export class Analyzer { return this.aiEnabled; } + getActiveProviderType(): ProviderType | null { + return this.providerManager?.getProviderType() || null; + } + + getActiveProviderName(): string | null { + return this.activeProvider?.name || null; + } + async analyze(filePath: string, useAI = false): Promise { const filename = basename(filePath); const ext = extname(filename).toLowerCase().slice(1); @@ -95,10 +146,10 @@ export class Analyzer { // Type-specific analysis failed } - // AI classification if enabled - if (useAI && this.aiEnabled && this.classifierService) { + // AI classification if enabled - use active provider + if (useAI && this.aiEnabled && this.activeProvider) { try { - const aiResult = await this.classifierService.classifyFile({ + const aiResult = await this.activeProvider.classifyFile({ filename: analysis.filename, content: analysis.textContent, metadata: analysis.metadata as Record | undefined, @@ -177,10 +228,10 @@ export class Analyzer { // Type-specific analysis failed } - // AI classification if enabled - if (useAI && this.aiEnabled && this.classifierService) { + // AI classification if enabled - use active provider + if (useAI && this.aiEnabled && this.activeProvider) { try { - const aiResult = await this.classifierService.classifyFile({ + const aiResult = await this.activeProvider.classifyFile({ filename: analysis.filename, content: analysis.textContent, metadata: analysis.metadata as Record | undefined, @@ -199,7 +250,7 @@ export class Analyzer { category: string; confidence: number; }> { - if (!this.aiEnabled || !this.classifierService) { + if (!this.aiEnabled || !this.activeProvider) { return { category: analysis.category, confidence: 0.5, @@ -207,7 +258,7 @@ export class Analyzer { } try { - return await this.classifierService.classifyFile({ + return await this.activeProvider.classifyFile({ filename: analysis.filename, content: analysis.textContent, metadata: analysis.metadata as Record | undefined, diff --git a/src/index.ts b/src/index.ts index fede4df..9b98783 100644 --- a/src/index.ts +++ b/src/index.ts @@ -15,6 +15,24 @@ export { EmbeddingService } from './ai/embeddings.js'; export { ClassifierService } from './ai/classifier.js'; export { OCRService } from './ai/ocr.js'; +// AI Providers exports +export { + type AIProvider, + type ProviderConfig, + type ProviderType, + type ClassificationRequest, + type ClassificationResult, + type ProviderManagerConfig, + ProviderManager, + createProvider, + listProviders, + OpenAIProvider, + AnthropicProvider, + GeminiProvider, + OllamaProvider, + LocalProvider, +} from './ai/providers/index.js'; + // Config exports export { loadConfig,