diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 3140693..0000000 --- a/Dockerfile +++ /dev/null @@ -1,10 +0,0 @@ -FROM golang:1.23 -WORKDIR /app -COPY . . -## for sqlite3 -ENV CGO_ENABLED=1 - -RUN go mod download -## fts5 is needed for sqlite full text search -RUN go build -tags "fts5" -o panda -CMD ["./panda"] diff --git a/Makefile b/Makefile index a30f394..705b297 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ -local-build: - go build -o panda +build: + go build -o panda --tags "fts5" -local-run: local-build +run: build -./panda rm ./panda diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index ce83ad0..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,27 +0,0 @@ -version: "3" - -services: - panda: - build: - context: . - dockerfile: Dockerfile - environment: - - PANDA_ENV=dev - - PANDA_DATA_DIR_PATH=/panda/data - - PANDA_DATABSE_NAME=panda.db - volumes: - - db_data:/panda/data - test-panda: - build: - context: . - dockerfile: Dockerfile - environment: - - PANDA_ENV=dev - - PANDA_DATA_DIR_PATH=/panda/data - - PANDA_DATABSE_NAME=panda.db - volumes: - - db_test_data:/panda/data - command: ["go", "test", "-v", "-tags", "fts5", "./..."] -volumes: - db_data: {} - db_test_data: {} \ No newline at end of file diff --git a/internal/config/config.go b/internal/config/config.go index 7d42ae2..c688faf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -31,6 +31,14 @@ func GetDir() string { return filepath.Join(configDir, appConfigDir) } +func GetDataDir() string { + dataDir := xdg.DataHome + if dataDir == "" { + dataDir = filepath.Join(xdg.Home, ".local", "share") + } + return filepath.Join(dataDir, appConfigDir) +} + func GetFilePath() string { configDir := GetDir() return filepath.Join(configDir, configFileName) diff --git a/internal/db/db.go b/internal/db/db.go index cdb35b9..6775351 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -2,10 +2,12 @@ package db import ( "fmt" - "github.com/jmoiron/sqlx" - _ "github.com/mattn/go-sqlite3" "os" "path/filepath" + + "github.com/aavshr/panda/internal/utils" + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" ) type Config struct { @@ -18,11 +20,10 @@ type Store struct { } func New(config Config, schemaInit, migrations *string) (*Store, error) { - // 0755 = rwxr-xr-x if err := os.MkdirAll(config.DataDirPath, 0755); err != nil { return nil, fmt.Errorf("could not make data dir, os.MkdirAll: %w", err) } - f, err := os.Create(filepath.Join(config.DataDirPath, config.DatabaseName)) + f, err := os.OpenFile(filepath.Join(config.DataDirPath, config.DatabaseName), os.O_RDWR|os.O_CREATE, 0644) if err != nil { return nil, fmt.Errorf("could not create database file, os.OpenFile: %w", err) } @@ -109,6 +110,29 @@ func (s *Store) DeleteAllThreadsTx(tx *sqlx.Tx) error { return nil } +func (s *Store) CreateMessage(message *Message) error { + // TODO: is this implicit behavior okay? + if message.ID == "" { + messageID, err := utils.RandomID() + if err != nil { + return fmt.Errorf("could not generate random id, utils.RandomID: %w", err) + } + message.ID = messageID + } + tx, err := s.db.Beginx() + if err != nil { + return fmt.Errorf("could not start transaction, db.Beginx: %w", err) + } + if err := s.CreateMessageTx(tx, message); err != nil { + tx.Rollback() + return fmt.Errorf("could not create message, CreateMessageTx: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("could not commit transaction, tx.Commit: %w", err) + } + return nil +} + func (s *Store) CreateMessageTx(tx *sqlx.Tx, message *Message) error { query := `INSERT INTO messages (id, m_role, content, created_at, thread_id) VALUES (:id, :m_role, :content, :created_at, :thread_id)` @@ -124,7 +148,7 @@ func (s *Store) CreateMessageTx(tx *sqlx.Tx, message *Message) error { func (s *Store) ListMessagesByThreadIDPaginated(threadID string, offset, limit int) ([]*Message, error) { var messages []*Message - err := s.db.Select(&messages, "SELECT * FROM messages WHERE thread_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3", threadID, limit, offset) + err := s.db.Select(&messages, "SELECT * FROM messages WHERE thread_id = $1 ORDER BY created_at LIMIT $2 OFFSET $3", threadID, limit, offset) if err != nil { return nil, fmt.Errorf("could not select messages, db.Select: %w", err) } @@ -148,3 +172,76 @@ func (s *Store) SearchMessageContentPaginated(term string, offset, limit int) ([ } return threads, nil } + +func (s *Store) UpsertThread(thread *Thread) error { + tx, err := s.db.Beginx() + if err != nil { + return fmt.Errorf("could not start transaction, db.Beginx: %w", err) + } + if err := s.UpsertThreadTx(tx, thread); err != nil { + tx.Rollback() + return fmt.Errorf("could not upsert thread, UpsertThreadTx: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("could not commit transaction, tx.Commit: %w", err) + } + return nil +} + +func (s *Store) UpsertThreadTx(tx *sqlx.Tx, thread *Thread) error { + query := `INSERT INTO threads (id, t_name, created_at, updated_at, external_message_store) + VALUES (:id, :t_name, :created_at, :updated_at, :external_message_store) + ON CONFLICT(id) DO UPDATE SET t_name = :t_name, updated_at = :updated_at` + if _, err := tx.NamedExec(query, thread); err != nil { + return fmt.Errorf("tx.NamedExec: %w", err) + } + query = `DELETE FROM virtual_thread_names WHERE thread_id = $1` + if _, err := tx.Exec(query, thread.ID); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + query = `INSERT INTO virtual_thread_names (thread_id, thread_name) VALUES ($1, $2)` + if _, err := tx.Exec(query, thread.ID, thread.Name); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + return nil +} + +func (s *Store) DeleteThread(id string) error { + tx, err := s.db.Beginx() + if err != nil { + return fmt.Errorf("could not start transaction, db.Beginx: %w", err) + } + if err := s.DeleteThreadTx(tx, id); err != nil { + tx.Rollback() + return fmt.Errorf("could not delete thread, DeleteThreadTx: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("could not commit transaction, tx.Commit: %w", err) + } + return nil +} + +func (s *Store) DeleteAllThreads() error { + tx, err := s.db.Beginx() + if err != nil { + return fmt.Errorf("could not start transaction, db.Beginx: %w", err) + } + query := `DELETE FROM threads` + if _, err := tx.Exec(query); err != nil { + return fmt.Errorf("could not delete threads, db.Exec: %w", err) + } + query = `DELETE FROM virtual_thread_names` + if _, err := tx.Exec(query); err != nil { + tx.Rollback() + return fmt.Errorf("could not delete virtual thread names, db.Exec: %w", err) + } + query = `DELETE FROM virtual_message_content` + if _, err := tx.Exec(query); err != nil { + tx.Rollback() + return fmt.Errorf("could not delete virtual message content, db.Exec: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("could not commit transaction, tx.Commit: %w", err) + } + return nil +} diff --git a/internal/db/schema/init.sql b/internal/db/schema/init.sql index cba4be7..92dfee2 100644 --- a/internal/db/schema/init.sql +++ b/internal/db/schema/init.sql @@ -14,13 +14,13 @@ CREATE TABLE IF NOT EXISTS messages ( thread_id TEXT REFERENCES threads(id) ON DELETE CASCADE ); -CREATE VIRTUAL TABLE virtual_thread_names USING fts5( +CREATE VIRTUAL TABLE IF NOT EXISTS virtual_thread_names USING fts5( thread_name, thread_id UNINDEXED ); -CREATE VIRTUAL TABLE virtual_message_content USING fts5( +CREATE VIRTUAL TABLE IF NOT EXISTS virtual_message_content USING fts5( message_content, message_id UNINDEXED, thread_id UNINDEXED -); \ No newline at end of file +); diff --git a/internal/ui/handlers.go b/internal/ui/handlers.go index 6a1e92c..5e371c5 100644 --- a/internal/ui/handlers.go +++ b/internal/ui/handlers.go @@ -179,19 +179,11 @@ func (m *Model) handleEscapeMsg() { } func (m *Model) handleListEnterMsg(msg components.ListEnterMsg) tea.Cmd { - // TODO: handle entering a message for messages - // either copy the entire message or let user copy only specific parts switch m.focusedComponent { case components.ComponentHistory: - // first item is always for new thread - if msg.Index == 0 { - m.setMessages([]*db.Message{}) - m.setSelectedComponent(components.ComponentChatInput) - m.setFocusedComponent(components.ComponentChatInput) - return nil - } - m.setSelectedComponent(components.ComponentMessages) - m.setFocusedComponent(components.ComponentMessages) + m.setSelectedComponent(components.ComponentChatInput) + m.setFocusedComponent(components.ComponentChatInput) + return nil } return nil } diff --git a/internal/ui/model.go b/internal/ui/model.go index 124dcd9..99a5a90 100644 --- a/internal/ui/model.go +++ b/internal/ui/model.go @@ -121,10 +121,8 @@ func New(conf *Config, store store.Store, llm llm.LLM) (*Model, error) { if err != nil { return m, fmt.Errorf("store.ListLatestThreadsPaginated %w", err) } - m.threads = append(m.threads, threads...) m.messages = []*db.Message{} - m.historyModel = components.NewListModel(&components.NewListModelInput{ Title: titleHistory, Items: components.NewThreadListItems(m.threads), @@ -133,6 +131,7 @@ func New(conf *Config, store store.Store, llm llm.LLM) (*Model, error) { Delegate: components.NewThreadListItemDelegate(), AllowInfiniteScrolling: false, }) + m.setThreads(append(m.threads, threads...)) m.historyModel.Select(0) // New Thread is selected by default m.messagesModel = components.NewChatModel(conf.messagesWidth, conf.messagesHeight) m.chatInputModel = components.NewChatInputModel(conf.chatInputWidth, conf.chatInputHeight) diff --git a/internal/ui/store/store.go b/internal/ui/store/store.go index e79fe66..f606a05 100644 --- a/internal/ui/store/store.go +++ b/internal/ui/store/store.go @@ -8,7 +8,6 @@ type Store interface { ListLatestThreadsPaginated(offset, limit int) ([]*db.Thread, error) ListMessagesByThreadIDPaginated(threadID string, offset, limit int) ([]*db.Message, error) UpsertThread(thread *db.Thread) error - UpdateThreadName(threadID, name string) error DeleteThread(threadID string) error DeleteAllThreads() error CreateMessage(message *db.Message) error diff --git a/main.go b/main.go index 57ac681..2245f2d 100644 --- a/main.go +++ b/main.go @@ -5,16 +5,15 @@ import ( "log" "os" + "strings" + + "github.com/aavshr/panda/internal/config" "github.com/aavshr/panda/internal/db" "github.com/aavshr/panda/internal/llm/openai" "github.com/aavshr/panda/internal/ui" - //"github.com/aavshr/panda/internal/ui/llm" "github.com/aavshr/panda/internal/ui/store" tea "github.com/charmbracelet/bubbletea" "golang.org/x/term" - //"log" - //"os" - //"strings" ) //go:embed internal/db/schema/init.sql @@ -24,35 +23,10 @@ var dbSchemaInit string var dbSchemaMigrations string const ( - DefaultDataDirPath = "/.local/share/panda/data" DefaultDatabaseName = "panda.db" ) -func main() { - /* - isDev := strings.ToLower(os.Getenv("PANDA_ENV")) == "dev" - dataDirPath := DefaultDataDirPath - databaseName := DefaultDatabaseName - if isDev { - devDataDirPath := os.Getenv("PANDA_DATA_DIR_PATH") - devDatabaseName := os.Getenv("PANDA_DATABASE_NAME") - if devDataDirPath != "" { - dataDirPath = devDataDirPath - } - if devDatabaseName != "" { - databaseName = devDatabaseName - } - } - - _, err := db.New(db.Config{ - DataDirPath: dataDirPath, - DatabaseName: databaseName, - }, &dbSchemaInit, &dbSchemaMigrations) - if err != nil { - log.Fatal("failed to initialize db: ", err) - } - */ - +func initMockStore() *store.Mock { testThreads := []*db.Thread{ { ID: "1", @@ -98,8 +72,32 @@ func main() { }, } - mockStore := store.NewMock(testThreads, testMessages) - //mockLLM := llm.NewMock() + return store.NewMock(testThreads, testMessages) +} + +func main() { + isDev := strings.ToLower(os.Getenv("PANDA_ENV")) == "dev" + dataDirPath := config.GetDataDir() + databaseName := DefaultDatabaseName + if isDev { + devDataDirPath := os.Getenv("PANDA_DATA_DIR_PATH") + devDatabaseName := os.Getenv("PANDA_DATABASE_NAME") + if devDataDirPath != "" { + dataDirPath = devDataDirPath + } + if devDatabaseName != "" { + databaseName = devDatabaseName + } + } + + dbStore, err := db.New(db.Config{ + DataDirPath: dataDirPath, + DatabaseName: databaseName, + }, &dbSchemaInit, &dbSchemaMigrations) + if err != nil { + log.Fatal("failed to initialize db: ", err) + } + openaiLLM := openai.New("") width, height, err := term.GetSize(int(os.Stdout.Fd())) @@ -110,9 +108,10 @@ func main() { m, err := ui.New(&ui.Config{ InitThreadsLimit: 10, MaxThreadsLimit: 100, + MessagesLimit: 50, Width: width - 8, Height: height - 10, - }, mockStore, openaiLLM) + }, dbStore, openaiLLM) if err != nil { log.Fatal("ui.New: ", err) }