agent test zed
This commit is contained in:
parent
41b597f8fc
commit
7728a4c767
34
Dockerfile
Normal file
34
Dockerfile
Normal file
@ -0,0 +1,34 @@
|
||||
FROM golang:1.21-alpine
|
||||
|
||||
# Install build dependencies
|
||||
RUN apk add --no-cache gcc musl-dev sqlite-dev
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy go mod files first for better caching
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Copy source code
|
||||
COPY . .
|
||||
|
||||
# Build the application
|
||||
RUN CGO_ENABLED=1 go build -o /todo-app
|
||||
|
||||
# Install Ollama for AI analysis
|
||||
RUN wget https://ollama.ai/install.sh -O install.sh && \
|
||||
chmod +x install.sh && \
|
||||
./install.sh
|
||||
|
||||
# Create volume for persistent data
|
||||
VOLUME /data
|
||||
|
||||
# Set environment variables
|
||||
ENV DB_PATH=/data/todoai.db \
|
||||
AI_MODEL_NAME=deepseek-r1:latest \
|
||||
AI_PROVIDER=ollama
|
||||
|
||||
# Expose port if needed for future API/web interface
|
||||
EXPOSE 8080
|
||||
|
||||
CMD ["/todo-app"]
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 TodoAI Contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
41
Makefile
Normal file
41
Makefile
Normal file
@ -0,0 +1,41 @@
|
||||
.PHONY: build test clean run docker-build docker-run
|
||||
|
||||
# Build variables
|
||||
BINARY_NAME=todo
|
||||
VERSION=$(shell git describe --tags --always --dirty)
|
||||
BUILD_TIME=$(shell date +%FT%T%z)
|
||||
LDFLAGS=-ldflags "-X main.Version=${VERSION} -X main.BuildTime=${BUILD_TIME}"
|
||||
|
||||
build:
|
||||
go build ${LDFLAGS} -o ${BINARY_NAME}
|
||||
|
||||
test:
|
||||
go test -v ./...
|
||||
|
||||
clean:
|
||||
go clean
|
||||
rm -f ${BINARY_NAME}
|
||||
rm -f todo.log
|
||||
rm -f test.db
|
||||
|
||||
run: build
|
||||
./${BINARY_NAME}
|
||||
|
||||
docker-build:
|
||||
docker build -t todoai:${VERSION} .
|
||||
|
||||
docker-run:
|
||||
docker run -it --rm \
|
||||
-v ${PWD}/data:/data \
|
||||
-e DB_PATH=/data/todoai.db \
|
||||
todoai:${VERSION}
|
||||
|
||||
install-deps:
|
||||
go mod download
|
||||
|
||||
lint:
|
||||
go vet ./...
|
||||
go fmt ./...
|
||||
|
||||
# Run this before committing
|
||||
pre-commit: lint test
|
||||
33
go.mod
Normal file
33
go.mod
Normal file
@ -0,0 +1,33 @@
|
||||
module github.com/your-username/todo
|
||||
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.23.6
|
||||
|
||||
require (
|
||||
github.com/charmbracelet/bubbles v0.16.1
|
||||
github.com/charmbracelet/bubbletea v0.24.2
|
||||
github.com/charmbracelet/lipgloss v0.7.1
|
||||
github.com/mattn/go-sqlite3 v1.14.17
|
||||
golang.org/x/time v0.11.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.18 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.14 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/termenv v0.15.1 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/sahilm/fuzzy v0.1.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.10.0 // indirect
|
||||
golang.org/x/term v0.10.0 // indirect
|
||||
golang.org/x/text v0.3.8 // indirect
|
||||
)
|
||||
50
go.sum
Normal file
50
go.sum
Normal file
@ -0,0 +1,50 @@
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/charmbracelet/bubbles v0.16.1 h1:6uzpAAaT9ZqKssntbvZMlksWHruQLNxg49H5WdeuYSY=
|
||||
github.com/charmbracelet/bubbles v0.16.1/go.mod h1:2QCp9LFlEsBQMvIYERr7Ww2H2bA7xen1idUDIzm/+Xc=
|
||||
github.com/charmbracelet/bubbletea v0.24.2 h1:uaQIKx9Ai6Gdh5zpTbGiWpytMU+CfsPp06RaW2cx/SY=
|
||||
github.com/charmbracelet/bubbletea v0.24.2/go.mod h1:XdrNrV4J8GiyshTtx3DNuYkR1FDaJmO3l2nejekbsgg=
|
||||
github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZpc5Kc1E=
|
||||
github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c=
|
||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
|
||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
|
||||
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34=
|
||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/termenv v0.15.1 h1:UzuTb/+hhlBugQz28rpzey4ZuKcZ03MeKsoG7IJZIxs=
|
||||
github.com/muesli/termenv v0.15.1/go.mod h1:HeAQPTzpfs016yGtA4g00CsdYnVLJvxsS4ANqrZs2sQ=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/sahilm/fuzzy v0.1.0 h1:FzWGaw2Opqyu+794ZQ9SYifWv2EIXpwP4q8dY1kDAwI=
|
||||
github.com/sahilm/fuzzy v0.1.0/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
|
||||
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
|
||||
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
44
improvements.md
Normal file
44
improvements.md
Normal file
@ -0,0 +1,44 @@
|
||||
# Todo App Improvement Plan
|
||||
|
||||
## 1. Code Structure Improvements
|
||||
|
||||
- Split code into multiple files for better organization (e.g., database.go, models.go, ui.go, etc.)
|
||||
- Use dependency injection to make code more testable
|
||||
- Add proper error handling and logging throughout the application
|
||||
|
||||
## 2. Database Enhancements
|
||||
|
||||
- Add database migrations to handle schema changes
|
||||
- Implement connection pooling for better performance
|
||||
- Add indexes to improve query performance
|
||||
|
||||
## 3. UI/UX Improvements
|
||||
|
||||
- Add support for task editing and deletion
|
||||
- Add search/filter functionality for tasks
|
||||
- Add support for task categories and tags
|
||||
- Add keyboard shortcuts for common actions
|
||||
|
||||
## 4. AI Integration
|
||||
|
||||
- Add support for multiple AI models
|
||||
- Add configuration file for AI model settings
|
||||
- Add fallback mechanism if AI model is unavailable
|
||||
|
||||
## 5. Testing
|
||||
|
||||
- Add unit tests for core functionality
|
||||
- Add integration tests for database and AI integration
|
||||
- Add end-to-end tests for the TUI
|
||||
|
||||
## 6. Documentation
|
||||
|
||||
- Add README with setup instructions and usage
|
||||
- Add API documentation for AI integration
|
||||
- Add changelog to track changes
|
||||
|
||||
## 7. Deployment
|
||||
|
||||
- Add Dockerfile for containerized deployment
|
||||
- Add CI/CD pipeline for automated testing and deployment
|
||||
- Add support for configuration via environment variables
|
||||
5
logs/analysis.log
Normal file
5
logs/analysis.log
Normal file
@ -0,0 +1,5 @@
|
||||
[2025-05-08 20:59:59] Task: "go to gym" Result: {"priority":3,"category":"work","estimate":15,"related":"","summary":"go to gym"}
|
||||
[2025-05-08 21:00:22] Task: "test" Result: {"priority":3,"category":"work","estimate":15,"related":"","summary":"test"}
|
||||
[2025-05-08 21:00:49] Task: "visit dentist" Result: {"priority":3,"category":"personal","estimate":15,"related":"","summary":"visit dentist"}
|
||||
[2025-05-08 21:07:37] Task: "test" Result: {"priority":3,"category":"work","estimate":15,"related":"","summary":"test"}
|
||||
[2025-05-08 21:07:53] Task: "work on communiction skills" Result: {"priority":3,"category":"work","estimate":30,"related":"","summary":"work on communiction skills"}
|
||||
66
main.go
Normal file
66
main.go
Normal file
@ -0,0 +1,66 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"github.com/your-username/todo/src/ai"
|
||||
"github.com/your-username/todo/src/config"
|
||||
"github.com/your-username/todo/src/database"
|
||||
"github.com/your-username/todo/src/ui"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Setup logging
|
||||
logFile, err := os.OpenFile("todo.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer logFile.Close()
|
||||
log.SetOutput(logFile)
|
||||
|
||||
// Create context that can be cancelled on signal
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Handle shutdown signals
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigChan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Load configuration
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
db, err := database.NewDatabase(cfg.Database)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Initialize AI service
|
||||
aiService := ai.NewAIService(&cfg.AI)
|
||||
|
||||
// Start background analyzer if enabled
|
||||
if cfg.UI.IdleStartHour >= 0 {
|
||||
analyzer := ai.NewBackgroundAnalyzer(db, aiService, cfg)
|
||||
go analyzer.Start(ctx)
|
||||
}
|
||||
|
||||
// Initialize and start UI
|
||||
if err := tea.NewProgram(ui.New(db, aiService)).Start(); err != nil {
|
||||
log.Printf("Error running program: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
203
src/ai/background.go
Normal file
203
src/ai/background.go
Normal file
@ -0,0 +1,203 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/your-username/todo/src/config"
|
||||
"github.com/your-username/todo/src/database"
|
||||
)
|
||||
|
||||
type BackgroundAnalyzer struct {
|
||||
db database.Database
|
||||
aiService *AIService
|
||||
config *config.Config
|
||||
rateLimiter *rate.Limiter
|
||||
stats *analyzerStats
|
||||
}
|
||||
|
||||
type analyzerStats struct {
|
||||
mu sync.Mutex
|
||||
totalProcessed int64
|
||||
totalSuccesses int64
|
||||
totalFailures int64
|
||||
lastProcessed time.Time
|
||||
lastError error
|
||||
consecutiveErrs int
|
||||
backoffUntil time.Time
|
||||
}
|
||||
|
||||
func NewBackgroundAnalyzer(db database.Database, aiService *AIService, cfg *config.Config) *BackgroundAnalyzer {
|
||||
// Process one task per minute by default
|
||||
rps := rate.Every(time.Minute)
|
||||
if cfg.AI.RequestsPerMinute > 0 {
|
||||
rps = rate.Every(time.Minute / time.Duration(cfg.AI.RequestsPerMinute))
|
||||
}
|
||||
|
||||
return &BackgroundAnalyzer{
|
||||
db: db,
|
||||
aiService: aiService,
|
||||
config: cfg,
|
||||
rateLimiter: rate.NewLimiter(rps, 1), // Burst of 1 for background tasks
|
||||
stats: &analyzerStats{},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackgroundAnalyzer) Start(ctx context.Context) {
|
||||
log.Printf("Starting background analyzer. Idle hours: %d:00-%d:00",
|
||||
b.config.UI.IdleStartHour,
|
||||
b.config.UI.IdleEndHour)
|
||||
|
||||
ticker := time.NewTicker(time.Duration(b.config.UI.RefreshInterval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("Background analyzer shutting down...")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !b.isIdleTime() {
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Now().Before(b.stats.backoffUntil) {
|
||||
continue
|
||||
}
|
||||
|
||||
if b.stats.consecutiveErrs >= 5 {
|
||||
backoff := 5 * time.Minute * time.Duration(1<<uint(b.stats.consecutiveErrs-5))
|
||||
if backoff > 1*time.Hour {
|
||||
backoff = 1 * time.Hour
|
||||
}
|
||||
log.Printf("Too many consecutive errors (%d), backing off for %v",
|
||||
b.stats.consecutiveErrs, backoff)
|
||||
b.stats.backoffUntil = time.Now().Add(backoff)
|
||||
continue
|
||||
}
|
||||
|
||||
b.processPendingTasks(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackgroundAnalyzer) isIdleTime() bool {
|
||||
hour := time.Now().Hour()
|
||||
start := b.config.UI.IdleStartHour
|
||||
end := b.config.UI.IdleEndHour
|
||||
|
||||
if start < end {
|
||||
return hour >= start && hour < end
|
||||
}
|
||||
// Handle overnight period (e.g., 22:00 - 06:00)
|
||||
return hour >= start || hour < end
|
||||
}
|
||||
|
||||
func (b *BackgroundAnalyzer) processPendingTasks(ctx context.Context) {
|
||||
// Get tasks without analysis
|
||||
tasks, err := b.db.GetTasks(ctx, false)
|
||||
if err != nil {
|
||||
log.Printf("Error getting tasks: %v", err)
|
||||
b.recordError(err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, task := range tasks {
|
||||
// Check rate limit
|
||||
if err := b.rateLimiter.Wait(ctx); err != nil {
|
||||
log.Printf("Rate limit wait failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Check if task already has analysis
|
||||
analysis, err := b.db.GetTaskAnalysis(ctx, int(task.ID))
|
||||
if err != nil {
|
||||
log.Printf("Error checking task analysis for task %d: %v", task.ID, err)
|
||||
b.recordError(err)
|
||||
continue
|
||||
}
|
||||
if analysis != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
taskCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
|
||||
// Analyze task
|
||||
result, err := b.aiService.AnalyzeTask(taskCtx, task.Title)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Error analyzing task %d: %v", task.ID, err)
|
||||
b.recordError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Store analysis result
|
||||
jsonResult, err := result.ToJSON()
|
||||
if err != nil {
|
||||
log.Printf("Error serializing analysis for task %d: %v", task.ID, err)
|
||||
b.recordError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = b.db.StoreAnalysis(ctx, &database.Analysis{
|
||||
TaskID: task.ID,
|
||||
Result: jsonResult,
|
||||
AnalyzedAt: time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Error storing analysis for task %d: %v", task.ID, err)
|
||||
b.recordError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
b.recordSuccess()
|
||||
log.Printf("Successfully analyzed task %d: %s", task.ID, result.Summary)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackgroundAnalyzer) recordError(err error) {
|
||||
b.stats.mu.Lock()
|
||||
defer b.stats.mu.Unlock()
|
||||
|
||||
b.stats.totalFailures++
|
||||
b.stats.consecutiveErrs++
|
||||
b.stats.lastError = err
|
||||
b.stats.lastProcessed = time.Now()
|
||||
}
|
||||
|
||||
func (b *BackgroundAnalyzer) recordSuccess() {
|
||||
b.stats.mu.Lock()
|
||||
defer b.stats.mu.Unlock()
|
||||
|
||||
b.stats.totalSuccesses++
|
||||
b.stats.totalProcessed++
|
||||
b.stats.consecutiveErrs = 0
|
||||
b.stats.lastProcessed = time.Now()
|
||||
}
|
||||
|
||||
// GetStats returns current analyzer statistics
|
||||
func (b *BackgroundAnalyzer) GetStats() *analyzerStats {
|
||||
b.stats.mu.Lock()
|
||||
defer b.stats.mu.Unlock()
|
||||
|
||||
// Return a copy to prevent concurrent access issues
|
||||
return &analyzerStats{
|
||||
totalProcessed: b.stats.totalProcessed,
|
||||
totalSuccesses: b.stats.totalSuccesses,
|
||||
totalFailures: b.stats.totalFailures,
|
||||
lastProcessed: b.stats.lastProcessed,
|
||||
lastError: b.stats.lastError,
|
||||
consecutiveErrs: b.stats.consecutiveErrs,
|
||||
backoffUntil: b.stats.backoffUntil,
|
||||
}
|
||||
}
|
||||
439
src/ai/service.go
Normal file
439
src/ai/service.go
Normal file
@ -0,0 +1,439 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/your-username/todo/src/config"
|
||||
)
|
||||
|
||||
type AIService struct {
|
||||
config *config.AIConfig
|
||||
client *http.Client
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
type AnalysisResult struct {
|
||||
Priority int `json:"priority"`
|
||||
Category string `json:"category"`
|
||||
TimeEstimate int `json:"estimate"`
|
||||
RelatedTasks string `json:"related"`
|
||||
Summary string `json:"summary"`
|
||||
}
|
||||
|
||||
type ollamaRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stream bool `json:"stream"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type ollamaResponse struct {
|
||||
Response string `json:"response"`
|
||||
}
|
||||
|
||||
var (
|
||||
analysisLogFile *os.File
|
||||
analysisLogger *log.Logger
|
||||
logMutex sync.Mutex
|
||||
)
|
||||
|
||||
func InitAnalysisLogger(logPath string) error {
|
||||
var err error
|
||||
analysisLogFile, err = os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open analysis log file: %w", err)
|
||||
}
|
||||
analysisLogger = log.New(analysisLogFile, "", log.LstdFlags)
|
||||
return nil
|
||||
}
|
||||
|
||||
func CloseAnalysisLogger() error {
|
||||
if analysisLogFile != nil {
|
||||
return analysisLogFile.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func logAnalysis(taskTitle string, result *AnalysisResult, source string) {
|
||||
if analysisLogger == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logMutex.Lock()
|
||||
defer logMutex.Unlock()
|
||||
|
||||
analysisLogger.Printf("Task: %q\nSource: %s\nPriority: %d\nCategory: %s\nTime Estimate: %d minutes\nSummary: %s\nRelated Tasks: %s\n---\n",
|
||||
taskTitle,
|
||||
source,
|
||||
result.Priority,
|
||||
result.Category,
|
||||
result.TimeEstimate,
|
||||
result.Summary,
|
||||
result.RelatedTasks)
|
||||
}
|
||||
|
||||
func (s *AIService) logAnalysis(taskTitle string, result *AnalysisResult, inputErr error) {
|
||||
logDir := "logs"
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
log.Printf("Failed to create log directory: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logFile := filepath.Join(logDir, "analysis.log")
|
||||
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open analysis log file: %v", err)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
timestamp := time.Now().Format("2006-01-02 15:04:05")
|
||||
var logEntry string
|
||||
if inputErr != nil {
|
||||
logEntry = fmt.Sprintf("[%s] Task: %q Error: %v\n", timestamp, taskTitle, inputErr)
|
||||
} else {
|
||||
jsonResult, _ := result.ToJSON()
|
||||
logEntry = fmt.Sprintf("[%s] Task: %q Result: %s\n", timestamp, taskTitle, jsonResult)
|
||||
}
|
||||
|
||||
if _, err := f.WriteString(logEntry); err != nil {
|
||||
log.Printf("Failed to write to analysis log: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func NewAIService(cfg *config.AIConfig) *AIService {
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(cfg.Timeout) * time.Second,
|
||||
}
|
||||
|
||||
// Create rate limiter - default 10 requests per minute
|
||||
rps := rate.Every(time.Minute / time.Duration(cfg.RequestsPerMinute))
|
||||
limiter := rate.NewLimiter(rps, cfg.BurstLimit)
|
||||
|
||||
return &AIService{
|
||||
config: cfg,
|
||||
client: client,
|
||||
limiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AIService) AnalyzeTask(ctx context.Context, taskTitle string) (*AnalysisResult, error) {
|
||||
var lastErr error
|
||||
|
||||
// Wait for rate limiter
|
||||
if err := s.limiter.Wait(ctx); err != nil {
|
||||
return s.fallbackAnalysis(taskTitle), nil
|
||||
}
|
||||
|
||||
// Try each configured model in order
|
||||
for _, model := range s.config.Models {
|
||||
var analyzer TaskAnalyzer
|
||||
switch model.Provider {
|
||||
case "ollama":
|
||||
analyzer = &OllamaAnalyzer{
|
||||
baseURL: s.config.OllamaBaseURL,
|
||||
model: model.Name,
|
||||
client: s.client,
|
||||
config: s.config,
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
// Try with retries
|
||||
for attempt := 0; attempt < s.config.MaxRetries; attempt++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return s.fallbackAnalysis(taskTitle), nil
|
||||
default:
|
||||
result, err := analyzer.Analyze(ctx, taskTitle)
|
||||
if err == nil {
|
||||
s.logAnalysis(taskTitle, result, nil)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Check if error is retryable
|
||||
var retryable bool
|
||||
if errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "connection refused") ||
|
||||
strings.Contains(err.Error(), "timeout") {
|
||||
retryable = true
|
||||
}
|
||||
|
||||
if !retryable {
|
||||
break
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
backoff := time.Duration(s.config.RetryDelay*(1<<attempt)) * time.Millisecond
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If all attempts failed, use fallback
|
||||
result := s.fallbackAnalysis(taskTitle)
|
||||
s.logAnalysis(taskTitle, result, lastErr)
|
||||
return result, lastErr
|
||||
}
|
||||
|
||||
func (s *AIService) fallbackAnalysis(taskTitle string) *AnalysisResult {
|
||||
result := &AnalysisResult{
|
||||
Priority: 3,
|
||||
Category: inferCategory(taskTitle),
|
||||
TimeEstimate: 30,
|
||||
Summary: taskTitle,
|
||||
}
|
||||
logAnalysis(taskTitle, result, "fallback")
|
||||
return result
|
||||
}
|
||||
|
||||
type TaskAnalyzer interface {
|
||||
Analyze(ctx context.Context, taskTitle string) (*AnalysisResult, error)
|
||||
}
|
||||
|
||||
type OllamaAnalyzer struct {
|
||||
baseURL string
|
||||
model string
|
||||
client *http.Client
|
||||
config *config.AIConfig
|
||||
}
|
||||
|
||||
func (a *OllamaAnalyzer) Analyze(ctx context.Context, taskTitle string) (*AnalysisResult, error) {
|
||||
prompt := fmt.Sprintf(`Analyze this task and provide JSON output with these fields:
|
||||
- priority (1-5, where 1 is highest)
|
||||
- category (work, personal, shopping, etc)
|
||||
- estimate (time estimate in minutes)
|
||||
- related (related tasks or follow-ups)
|
||||
- summary (brief task description)
|
||||
|
||||
Task: %s
|
||||
|
||||
Please ensure the response is valid JSON.`, taskTitle)
|
||||
|
||||
req := ollamaRequest{
|
||||
Model: a.model,
|
||||
Prompt: prompt,
|
||||
Stream: false,
|
||||
MaxTokens: a.config.MaxTokens,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/api/generate", strings.TrimRight(a.baseURL, "/"))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := a.client.Do(httpReq)
|
||||
if err != nil {
|
||||
// Check if error matches any retryable patterns
|
||||
for _, pattern := range a.config.RetryableErrors {
|
||||
if strings.Contains(err.Error(), pattern) {
|
||||
return nil, fmt.Errorf("retryable error: %w", err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
return nil, fmt.Errorf("retryable error: rate limit exceeded")
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 500 {
|
||||
return nil, fmt.Errorf("retryable error: server error %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var ollamaResp ollamaResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
// Clean up response text by removing ANSI escape sequences and markers
|
||||
cleanText := stripAnsiCodes(ollamaResp.Response)
|
||||
|
||||
// Try to extract JSON from the response text
|
||||
var result AnalysisResult
|
||||
jsonStart := strings.Index(cleanText, "{")
|
||||
jsonEnd := strings.LastIndex(cleanText, "}")
|
||||
|
||||
if jsonStart >= 0 && jsonEnd > jsonStart {
|
||||
jsonStr := cleanText[jsonStart : jsonEnd+1]
|
||||
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
|
||||
log.Printf("JSON parse error: %v\nContent: %q", err, jsonStr)
|
||||
// Try to extract structured data from text
|
||||
extractedResult := a.extractFromText(cleanText, taskTitle)
|
||||
logAnalysis(taskTitle, extractedResult, "text-extraction")
|
||||
return extractedResult, nil
|
||||
}
|
||||
|
||||
validatedResult := a.validateResult(&result, taskTitle)
|
||||
logAnalysis(taskTitle, validatedResult, "ollama-json")
|
||||
return validatedResult, nil
|
||||
}
|
||||
|
||||
// If no JSON found, try to extract structured data from text
|
||||
extractedResult := a.extractFromText(cleanText, taskTitle)
|
||||
logAnalysis(taskTitle, extractedResult, "text-extraction")
|
||||
return extractedResult, nil
|
||||
}
|
||||
|
||||
// stripAnsiCodes removes ANSI escape sequences and code block markers from text
|
||||
func stripAnsiCodes(text string) string {
|
||||
// Remove ANSI escape sequences
|
||||
ansiRegex := regexp.MustCompile(`\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])`)
|
||||
text = ansiRegex.ReplaceAllString(text, "")
|
||||
|
||||
// Remove code block markers and "json" language identifier
|
||||
text = regexp.MustCompile("(?m)^```json\\s*$").ReplaceAllString(text, "")
|
||||
text = regexp.MustCompile("(?m)^```\\s*$").ReplaceAllString(text, "")
|
||||
|
||||
// Clean up any remaining whitespace
|
||||
text = strings.TrimSpace(text)
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
func (a *OllamaAnalyzer) extractFromText(text string, taskTitle string) *AnalysisResult {
|
||||
result := &AnalysisResult{
|
||||
Summary: taskTitle,
|
||||
}
|
||||
|
||||
// Extract priority (1-5)
|
||||
if idx := strings.Index(text, "priority"); idx >= 0 {
|
||||
for i := idx; i < len(text) && i < idx+20; i++ {
|
||||
if p := text[i]; p >= '1' && p <= '5' {
|
||||
result.Priority = int(p - '0')
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract category
|
||||
categories := []string{"work", "personal", "shopping", "health", "education"}
|
||||
for _, cat := range categories {
|
||||
if strings.Contains(strings.ToLower(text), cat) {
|
||||
result.Category = cat
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Extract time estimate
|
||||
if idx := strings.Index(text, "estimate"); idx >= 0 {
|
||||
var estimate int
|
||||
fmt.Sscanf(text[idx:], "estimate: %d", &estimate)
|
||||
if estimate > 0 {
|
||||
result.TimeEstimate = estimate
|
||||
}
|
||||
}
|
||||
|
||||
return a.validateResult(result, taskTitle)
|
||||
}
|
||||
|
||||
func (a *OllamaAnalyzer) validateResult(result *AnalysisResult, taskTitle string) *AnalysisResult {
|
||||
if result == nil {
|
||||
result = &AnalysisResult{}
|
||||
}
|
||||
|
||||
// Validate and set defaults
|
||||
if result.Priority < 1 || result.Priority > 5 {
|
||||
// Try to infer priority from keywords
|
||||
priority := 3 // default medium priority
|
||||
lowPriority := []string{"whenever", "someday", "maybe", "if possible", "low"}
|
||||
highPriority := []string{"urgent", "asap", "important", "critical", "due", "deadline"}
|
||||
|
||||
taskLower := strings.ToLower(taskTitle)
|
||||
for _, word := range lowPriority {
|
||||
if strings.Contains(taskLower, word) {
|
||||
priority = 4
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, word := range highPriority {
|
||||
if strings.Contains(taskLower, word) {
|
||||
priority = 2
|
||||
break
|
||||
}
|
||||
}
|
||||
result.Priority = priority
|
||||
}
|
||||
|
||||
if result.TimeEstimate <= 0 {
|
||||
// Set default based on typical task duration
|
||||
switch {
|
||||
case len(taskTitle) < 20:
|
||||
result.TimeEstimate = 15 // Quick tasks
|
||||
case len(taskTitle) < 50:
|
||||
result.TimeEstimate = 30 // Medium tasks
|
||||
default:
|
||||
result.TimeEstimate = 60 // Longer tasks
|
||||
}
|
||||
}
|
||||
|
||||
if result.Category == "" {
|
||||
result.Category = inferCategory(taskTitle)
|
||||
}
|
||||
|
||||
if result.Summary == "" {
|
||||
result.Summary = taskTitle
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *AnalysisResult) ToJSON() (string, error) {
|
||||
data, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func inferCategory(taskTitle string) string {
|
||||
taskTitle = strings.ToLower(taskTitle)
|
||||
|
||||
categories := map[string][]string{
|
||||
"work": {"report", "meeting", "project", "email", "presentation", "deadline"},
|
||||
"education": {"study", "learn", "read", "homework", "assignment", "course"},
|
||||
"shopping": {"buy", "purchase", "grocery", "store", "market", "shop"},
|
||||
"health": {"exercise", "gym", "workout", "doctor", "medical", "appointment"},
|
||||
"maintenance": {"clean", "fix", "repair", "maintain", "organize"},
|
||||
"communication": {"call", "email", "message", "contact", "meet"},
|
||||
}
|
||||
|
||||
for category, keywords := range categories {
|
||||
for _, keyword := range keywords {
|
||||
if strings.Contains(taskTitle, keyword) {
|
||||
return category
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "general"
|
||||
}
|
||||
295
src/ai/service_test.go
Normal file
295
src/ai/service_test.go
Normal file
@ -0,0 +1,295 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/your-username/todo/src/config"
|
||||
)
|
||||
|
||||
// MockAnalyzer implements Analyzer interface for testing
|
||||
type MockAnalyzer struct {
|
||||
result *AnalysisResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *MockAnalyzer) AnalyzeTask(_ context.Context, _ string) (*AnalysisResult, error) {
|
||||
return m.result, m.err
|
||||
}
|
||||
|
||||
func TestAIService(t *testing.T) {
|
||||
mockResult := &AnalysisResult{
|
||||
Priority: 3,
|
||||
Category: "test",
|
||||
TimeEstimate: 30,
|
||||
Summary: "Test task",
|
||||
}
|
||||
|
||||
t.Run("SingleModel", func(t *testing.T) {
|
||||
service := NewAIService([]ModelConfig{{Name: "mock", Provider: "mock"}})
|
||||
service.analyzers = map[string]Analyzer{
|
||||
"mock": &MockAnalyzer{result: mockResult},
|
||||
}
|
||||
|
||||
result, err := service.AnalyzeTask(context.Background(), "Test task")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("Expected analysis result")
|
||||
}
|
||||
if result.Priority != mockResult.Priority {
|
||||
t.Errorf("Expected priority %d, got %d", mockResult.Priority, result.Priority)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Fallback", func(t *testing.T) {
|
||||
service := NewAIService([]ModelConfig{{Name: "failing", Provider: "mock"}})
|
||||
service.analyzers = map[string]Analyzer{
|
||||
"failing": &MockAnalyzer{err: fmt.Errorf("mock error")},
|
||||
}
|
||||
|
||||
result, err := service.AnalyzeTask(context.Background(), "Test task")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error with fallback, got %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("Expected fallback analysis result")
|
||||
}
|
||||
if result.Category != "uncategorized" {
|
||||
t.Errorf("Expected uncategorized category, got %s", result.Category)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MultipleModels", func(t *testing.T) {
|
||||
service := NewAIService([]ModelConfig{
|
||||
{Name: "failing1", Provider: "mock"},
|
||||
{Name: "working", Provider: "mock"},
|
||||
})
|
||||
service.analyzers = map[string]Analyzer{
|
||||
"failing1": &MockAnalyzer{err: fmt.Errorf("mock error")},
|
||||
"working": &MockAnalyzer{result: mockResult},
|
||||
}
|
||||
|
||||
result, err := service.AnalyzeTask(context.Background(), "Test task")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("Expected analysis result from second model")
|
||||
}
|
||||
if result.Priority != mockResult.Priority {
|
||||
t.Errorf("Expected priority %d, got %d", mockResult.Priority, result.Priority)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAIService_AnalyzeTask(t *testing.T) {
|
||||
// Mock Ollama server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/generate" {
|
||||
t.Errorf("Expected to request '/api/generate', got: %s", r.URL.Path)
|
||||
}
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST request, got: %s", r.Method)
|
||||
}
|
||||
|
||||
response := ollamaResponse{
|
||||
Response: `{
|
||||
"priority": 2,
|
||||
"category": "work",
|
||||
"time_estimate": 60,
|
||||
"related_tasks": "Follow up with team",
|
||||
"summary": "Project status report for Q2"
|
||||
}`,
|
||||
}
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create test config
|
||||
cfg := &config.AIConfig{
|
||||
OllamaBaseURL: server.URL,
|
||||
MaxRetries: 2,
|
||||
RetryDelay: 100,
|
||||
Timeout: 5,
|
||||
MaxTokens: 2048,
|
||||
Models: []config.AIModelConfig{
|
||||
{
|
||||
Name: "test-model",
|
||||
Provider: "ollama",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
service := NewAIService(cfg)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := service.AnalyzeTask(ctx, "Create project status report")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to analyze task: %v", err)
|
||||
}
|
||||
|
||||
if result.Priority != 2 {
|
||||
t.Errorf("Expected priority 2, got %d", result.Priority)
|
||||
}
|
||||
if result.Category != "work" {
|
||||
t.Errorf("Expected category 'work', got %s", result.Category)
|
||||
}
|
||||
if result.TimeEstimate != 60 {
|
||||
t.Errorf("Expected time estimate 60, got %d", result.TimeEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIService_Fallback(t *testing.T) {
|
||||
// Create config with invalid URL to force fallback
|
||||
cfg := &config.AIConfig{
|
||||
OllamaBaseURL: "http://invalid-url",
|
||||
MaxRetries: 1,
|
||||
RetryDelay: 100,
|
||||
Timeout: 1,
|
||||
Models: []config.AIModelConfig{
|
||||
{
|
||||
Name: "test-model",
|
||||
Provider: "ollama",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
service := NewAIService(cfg)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := service.AnalyzeTask(ctx, "Buy groceries")
|
||||
if err == nil {
|
||||
t.Error("Expected error due to invalid URL, got nil")
|
||||
}
|
||||
|
||||
// Should still get a result from fallback
|
||||
if result == nil {
|
||||
t.Fatal("Expected fallback result, got nil")
|
||||
}
|
||||
if result.Category != "shopping" {
|
||||
t.Errorf("Expected category 'shopping', got %s", result.Category)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferCategory(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Education", "Learn Python programming", "education"},
|
||||
{"Work", "Complete project report", "work"},
|
||||
{"Shopping", "Buy groceries", "shopping"},
|
||||
{"Communication", "Call mom", "communication"},
|
||||
{"Maintenance", "Fix the leaky faucet", "maintenance"},
|
||||
{"Health", "Go to the gym", "health"},
|
||||
{"Default", "Random task", "general"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := inferCategory(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("inferCategory(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnalysisResult_ToJSON(t *testing.T) {
|
||||
result := &AnalysisResult{
|
||||
Priority: 1,
|
||||
Category: "work",
|
||||
TimeEstimate: 30,
|
||||
RelatedTasks: "Follow up",
|
||||
Summary: "Test task",
|
||||
}
|
||||
|
||||
json, err := result.ToJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert to JSON: %v", err)
|
||||
}
|
||||
|
||||
expected := `{"priority":1,"category":"work","time_estimate":30,"related_tasks":"Follow up","summary":"Test task"}`
|
||||
if json != expected {
|
||||
t.Errorf("Expected JSON %q, got %q", expected, json)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOllamaAnalyzer_ValidateResult(t *testing.T) {
|
||||
analyzer := &OllamaAnalyzer{
|
||||
model: "test-model",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *AnalysisResult
|
||||
taskTitle string
|
||||
wantPriority int
|
||||
wantEstimate int
|
||||
}{
|
||||
{
|
||||
name: "Valid result",
|
||||
input: &AnalysisResult{
|
||||
Priority: 2,
|
||||
TimeEstimate: 45,
|
||||
Category: "work",
|
||||
Summary: "Valid task",
|
||||
},
|
||||
taskTitle: "Test task",
|
||||
wantPriority: 2,
|
||||
wantEstimate: 45,
|
||||
},
|
||||
{
|
||||
name: "Invalid priority",
|
||||
input: &AnalysisResult{
|
||||
Priority: 0,
|
||||
TimeEstimate: 45,
|
||||
Category: "work",
|
||||
},
|
||||
taskTitle: "Test task",
|
||||
wantPriority: 3,
|
||||
wantEstimate: 45,
|
||||
},
|
||||
{
|
||||
name: "Invalid estimate",
|
||||
input: &AnalysisResult{
|
||||
Priority: 2,
|
||||
TimeEstimate: 0,
|
||||
Category: "work",
|
||||
},
|
||||
taskTitle: "Test task",
|
||||
wantPriority: 2,
|
||||
wantEstimate: 30,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := analyzer.validateResult(tt.input, tt.taskTitle)
|
||||
if result.Priority != tt.wantPriority {
|
||||
t.Errorf("Priority = %d; want %d", result.Priority, tt.wantPriority)
|
||||
}
|
||||
if result.TimeEstimate != tt.wantEstimate {
|
||||
t.Errorf("TimeEstimate = %d; want %d", result.TimeEstimate, tt.wantEstimate)
|
||||
}
|
||||
if result.Summary == "" {
|
||||
t.Error("Summary should not be empty")
|
||||
}
|
||||
if result.Category == "" {
|
||||
t.Error("Category should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
168
src/config/config.go
Normal file
168
src/config/config.go
Normal file
@ -0,0 +1,168 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Config holds all application configuration
|
||||
type Config struct {
|
||||
Database DatabaseConfig
|
||||
AI AIConfig
|
||||
UI UIConfig
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Path string
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime int // in minutes
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
DefaultModel string
|
||||
Models []AIModelConfig
|
||||
MaxRetries int
|
||||
RetryDelay int // in milliseconds
|
||||
OllamaBaseURL string // URL for Ollama API
|
||||
Timeout int // request timeout in seconds
|
||||
RequestsPerMinute int // Rate limit in requests per minute
|
||||
BurstLimit int // Maximum burst size for rate limiter
|
||||
BackgroundEnabled bool // Whether to enable background analysis
|
||||
MaxTokens int // Maximum tokens to generate in responses
|
||||
RetryableErrors []string // List of error strings that should trigger retries
|
||||
}
|
||||
|
||||
type AIModelConfig struct {
|
||||
Name string
|
||||
Provider string
|
||||
APIKey string
|
||||
Priority int // Lower numbers mean higher priority
|
||||
Tags []string // Model capabilities/specialties
|
||||
MaxTokens int // Maximum tokens per request
|
||||
}
|
||||
|
||||
type UIConfig struct {
|
||||
Theme string
|
||||
RefreshInterval int // in seconds
|
||||
IdleStartHour int // Hour to start background analysis (24h format)
|
||||
IdleEndHour int // Hour to end background analysis (24h format)
|
||||
MaxDisplayTasks int
|
||||
}
|
||||
|
||||
// LoadConfig loads configuration from environment variables
|
||||
func LoadConfig() (*Config, error) {
|
||||
cfg := &Config{
|
||||
Database: DatabaseConfig{
|
||||
Path: getEnv("DB_PATH", "todoai.db"),
|
||||
MaxOpenConns: getEnvInt("DB_MAX_OPEN_CONNS", 25),
|
||||
MaxIdleConns: getEnvInt("DB_MAX_IDLE_CONNS", 5),
|
||||
ConnMaxLifetime: getEnvInt("DB_CONN_MAX_LIFETIME", 5),
|
||||
},
|
||||
AI: AIConfig{
|
||||
DefaultModel: getEnv("AI_DEFAULT_MODEL", "gemma3:1b"),
|
||||
MaxRetries: getEnvInt("AI_MAX_RETRIES", 3),
|
||||
RetryDelay: getEnvInt("AI_RETRY_DELAY_MS", 500),
|
||||
OllamaBaseURL: getEnv("AI_OLLAMA_URL", "http://localhost:11434"),
|
||||
Timeout: getEnvInt("AI_TIMEOUT_SECONDS", 10),
|
||||
RequestsPerMinute: getEnvInt("AI_REQUESTS_PER_MINUTE", 10),
|
||||
BurstLimit: getEnvInt("AI_BURST_LIMIT", 3),
|
||||
BackgroundEnabled: getEnvBool("AI_BACKGROUND_ENABLED", true),
|
||||
MaxTokens: getEnvInt("AI_MAX_TOKENS", 2048),
|
||||
RetryableErrors: getEnvArray("AI_RETRYABLE_ERRORS", []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"too many requests",
|
||||
"server error",
|
||||
}),
|
||||
Models: []AIModelConfig{
|
||||
{
|
||||
Name: getEnv("AI_MODEL_NAME", "gemma3:1b"),
|
||||
Provider: getEnv("AI_PROVIDER", "ollama"),
|
||||
APIKey: getEnv("AI_API_KEY", ""),
|
||||
Priority: 1,
|
||||
Tags: getEnvArray("AI_MODEL_TAGS", []string{"general"}),
|
||||
MaxTokens: getEnvInt("AI_MAX_TOKENS", 2048),
|
||||
},
|
||||
{
|
||||
Name: "mistral:7b",
|
||||
Provider: "ollama",
|
||||
Priority: 2,
|
||||
Tags: []string{"general", "fallback"},
|
||||
MaxTokens: 2048,
|
||||
},
|
||||
},
|
||||
},
|
||||
UI: UIConfig{
|
||||
Theme: getEnv("UI_THEME", "default"),
|
||||
RefreshInterval: getEnvInt("UI_REFRESH_INTERVAL", 60),
|
||||
IdleStartHour: getEnvInt("UI_IDLE_START_HOUR", 22),
|
||||
IdleEndHour: getEnvInt("UI_IDLE_END_HOUR", 6),
|
||||
MaxDisplayTasks: getEnvInt("UI_MAX_DISPLAY_TASKS", 100),
|
||||
},
|
||||
}
|
||||
|
||||
return cfg, cfg.validate()
|
||||
}
|
||||
|
||||
func (c *Config) validate() error {
|
||||
if c.Database.Path == "" {
|
||||
return fmt.Errorf("database path cannot be empty")
|
||||
}
|
||||
|
||||
if len(c.AI.Models) == 0 {
|
||||
return fmt.Errorf("at least one AI model must be configured")
|
||||
}
|
||||
|
||||
if c.AI.RequestsPerMinute < 1 {
|
||||
return fmt.Errorf("AI requests per minute must be at least 1")
|
||||
}
|
||||
|
||||
if c.AI.BurstLimit < 1 {
|
||||
return fmt.Errorf("AI burst limit must be at least 1")
|
||||
}
|
||||
|
||||
if c.UI.IdleStartHour < 0 || c.UI.IdleStartHour > 23 {
|
||||
return fmt.Errorf("idle start hour must be between 0 and 23")
|
||||
}
|
||||
|
||||
if c.UI.IdleEndHour < 0 || c.UI.IdleEndHour > 23 {
|
||||
return fmt.Errorf("idle end hour must be between 0 and 23")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value, exists := os.LookupEnv(key); exists {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
if value, exists := os.LookupEnv(key); exists {
|
||||
if intVal, err := strconv.Atoi(value); err == nil {
|
||||
return intVal
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvArray(key string, defaultValue []string) []string {
|
||||
if value, exists := os.LookupEnv(key); exists {
|
||||
return strings.Split(value, ",")
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvBool(key string, fallback bool) bool {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
if b, err := strconv.ParseBool(value); err == nil {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
173
src/database/database.go
Normal file
173
src/database/database.go
Normal file
@ -0,0 +1,173 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/your-username/todo/src/config"
|
||||
)
|
||||
|
||||
// Database interface defines the methods required for task storage and analysis
|
||||
type Database interface {
|
||||
Close() error
|
||||
AddTask(ctx context.Context, title string, body string) (int64, error)
|
||||
UpdateTask(ctx context.Context, id int64, title string, body string) error
|
||||
CompleteTask(ctx context.Context, id int) error
|
||||
GetTasks(ctx context.Context, includeCompleted bool) ([]Task, error)
|
||||
GetTaskAnalysis(ctx context.Context, taskID int) (*Analysis, error)
|
||||
StoreAnalysis(ctx context.Context, analysis *Analysis) error
|
||||
}
|
||||
|
||||
// SQLiteDB implements the Database interface
|
||||
type SQLiteDB struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// Task represents a todo task
|
||||
type Task struct {
|
||||
ID int64
|
||||
Title string
|
||||
Body string
|
||||
Completed bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// Analysis represents an AI analysis of a task
|
||||
type Analysis struct {
|
||||
TaskID int64
|
||||
Result string
|
||||
AnalyzedAt time.Time
|
||||
}
|
||||
|
||||
func NewDatabase(cfg config.DatabaseConfig) (Database, error) {
|
||||
db, err := sql.Open("sqlite3", cfg.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Minute)
|
||||
|
||||
// Initialize schema
|
||||
if err := initSchema(db); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("error initializing schema: %w", err)
|
||||
}
|
||||
|
||||
return &SQLiteDB{db: db}, nil
|
||||
}
|
||||
|
||||
func initSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
title TEXT NOT NULL,
|
||||
body TEXT,
|
||||
completed BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS analysis (
|
||||
task_id INTEGER,
|
||||
result TEXT,
|
||||
analyzed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(task_id) REFERENCES tasks(id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_completed ON tasks(completed);
|
||||
CREATE INDEX IF NOT EXISTS idx_analysis_task_id ON analysis(task_id);
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *SQLiteDB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
func (d *SQLiteDB) AddTask(ctx context.Context, title string, body string) (int64, error) {
|
||||
result, err := d.db.ExecContext(ctx,
|
||||
"INSERT INTO tasks (title, body) VALUES (?, ?)",
|
||||
title, body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *SQLiteDB) UpdateTask(ctx context.Context, id int64, title string, body string) error {
|
||||
result, err := d.db.ExecContext(ctx,
|
||||
"UPDATE tasks SET title = ?, body = ? WHERE id = ?",
|
||||
title, body, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *SQLiteDB) CompleteTask(ctx context.Context, id int) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE tasks SET completed = TRUE WHERE id = ?",
|
||||
id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *SQLiteDB) GetTasks(ctx context.Context, includeCompleted bool) ([]Task, error) {
|
||||
query := "SELECT id, title, body, completed, created_at FROM tasks"
|
||||
if !includeCompleted {
|
||||
query += " WHERE completed = FALSE"
|
||||
}
|
||||
query += " ORDER BY created_at DESC"
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []Task
|
||||
for rows.Next() {
|
||||
var t Task
|
||||
err := rows.Scan(&t.ID, &t.Title, &t.Body, &t.Completed, &t.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tasks = append(tasks, t)
|
||||
}
|
||||
return tasks, rows.Err()
|
||||
}
|
||||
|
||||
func (d *SQLiteDB) GetTaskAnalysis(ctx context.Context, taskID int) (*Analysis, error) {
|
||||
var analysis Analysis
|
||||
err := d.db.QueryRowContext(ctx, `
|
||||
SELECT task_id, result, analyzed_at
|
||||
FROM analysis
|
||||
WHERE task_id = ?
|
||||
ORDER BY analyzed_at DESC LIMIT 1`,
|
||||
taskID).Scan(&analysis.TaskID, &analysis.Result, &analysis.AnalyzedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &analysis, nil
|
||||
}
|
||||
|
||||
func (d *SQLiteDB) StoreAnalysis(ctx context.Context, analysis *Analysis) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"INSERT INTO analysis (task_id, result, analyzed_at) VALUES (?, ?, ?)",
|
||||
analysis.TaskID, analysis.Result, analysis.AnalyzedAt)
|
||||
return err
|
||||
}
|
||||
99
src/database/database_test.go
Normal file
99
src/database/database_test.go
Normal file
@ -0,0 +1,99 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSQLiteDB(t *testing.T) {
|
||||
// Create temp database for testing
|
||||
tmpDB := "test.db"
|
||||
db, err := NewSQLiteDB(tmpDB)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpDB)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := db.Init(ctx); err != nil {
|
||||
t.Fatalf("Failed to initialize database: %v", err)
|
||||
}
|
||||
|
||||
t.Run("AddTask", func(t *testing.T) {
|
||||
id, err := db.AddTask(ctx, "Test task")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to add task: %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Error("Expected positive task ID")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetTasks", func(t *testing.T) {
|
||||
tasks, err := db.GetTasks(ctx, false)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get tasks: %v", err)
|
||||
}
|
||||
if len(tasks) == 0 {
|
||||
t.Error("Expected at least one task")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CompleteTask", func(t *testing.T) {
|
||||
// Add a task first
|
||||
id, _ := db.AddTask(ctx, "Task to complete")
|
||||
|
||||
// Complete it
|
||||
err := db.CompleteTask(ctx, int(id))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to complete task: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's in completed tasks
|
||||
tasks, _ := db.GetTasks(ctx, true)
|
||||
found := false
|
||||
for _, task := range tasks {
|
||||
if task.ID == id {
|
||||
found = true
|
||||
if !task.Completed {
|
||||
t.Error("Task should be marked as completed")
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Completed task not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TaskAnalysis", func(t *testing.T) {
|
||||
// Add a task
|
||||
id, _ := db.AddTask(ctx, "Task for analysis")
|
||||
|
||||
// Store analysis
|
||||
analysis := &Analysis{
|
||||
TaskID: id,
|
||||
Result: `{"priority":3,"category":"test"}`,
|
||||
AnalyzedAt: time.Now(),
|
||||
}
|
||||
|
||||
err := db.StoreAnalysis(ctx, analysis)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to store analysis: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve analysis
|
||||
retrieved, err := db.GetTaskAnalysis(ctx, int(id))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get analysis: %v", err)
|
||||
}
|
||||
if retrieved == nil {
|
||||
t.Error("Expected to retrieve analysis")
|
||||
}
|
||||
if retrieved.Result != analysis.Result {
|
||||
t.Errorf("Expected result %s, got %s", analysis.Result, retrieved.Result)
|
||||
}
|
||||
})
|
||||
}
|
||||
25
src/models/task.go
Normal file
25
src/models/task.go
Normal file
@ -0,0 +1,25 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Task represents a todo item
|
||||
type Task struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Completed bool `json:"completed"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
// TaskAnalysis represents the AI analysis of a task
|
||||
type TaskAnalysis struct {
|
||||
TaskID int64 `json:"task_id"`
|
||||
Priority int `json:"priority"`
|
||||
Category string `json:"category"`
|
||||
TimeEstimate int `json:"time_estimate"`
|
||||
RelatedTasks string `json:"related_tasks"`
|
||||
Summary string `json:"summary"`
|
||||
AnalyzedAt time.Time `json:"analyzed_at"`
|
||||
}
|
||||
552
src/ui/tui.go
Normal file
552
src/ui/tui.go
Normal file
@ -0,0 +1,552 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/list"
|
||||
"github.com/charmbracelet/bubbles/spinner"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
"github.com/your-username/todo/src/ai"
|
||||
"github.com/your-username/todo/src/database"
|
||||
"github.com/your-username/todo/src/models"
|
||||
)
|
||||
|
||||
var (
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("62"))
|
||||
|
||||
completedStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("42")) // Green color
|
||||
|
||||
taskStyle = lipgloss.NewStyle().
|
||||
BorderStyle(lipgloss.NormalBorder()).
|
||||
BorderForeground(lipgloss.Color("240")).
|
||||
Padding(1, 2)
|
||||
|
||||
statusStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("34")) // Blue color
|
||||
|
||||
dimmedStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("240")) // Dimmed gray color
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
list list.Model // Changed from *list.Model to list.Model
|
||||
titleInput textinput.Model
|
||||
bodyInput textinput.Model
|
||||
searchInput textinput.Model
|
||||
spinner spinner.Model
|
||||
db database.Database
|
||||
aiService *ai.AIService
|
||||
state State
|
||||
selected *database.Task
|
||||
errorMsg string
|
||||
filter string
|
||||
currentPage int
|
||||
itemsPerPage int
|
||||
showCompleted bool
|
||||
analysisResults map[int64]*database.Analysis
|
||||
isLoading bool // Track loading state
|
||||
analyzingTasks map[int]bool
|
||||
analysisTimeout time.Duration
|
||||
}
|
||||
|
||||
type State int
|
||||
|
||||
const (
|
||||
StateNormal State = iota
|
||||
StateCreating
|
||||
StateEditing
|
||||
StateSearching
|
||||
StateViewing
|
||||
StateAnalyzing
|
||||
)
|
||||
|
||||
func New(db database.Database, aiService *ai.AIService) Model {
|
||||
ti := textinput.New()
|
||||
ti.Placeholder = "Task title"
|
||||
ti.Focus()
|
||||
|
||||
bi := textinput.New()
|
||||
bi.Placeholder = "Task description (optional)"
|
||||
|
||||
si := textinput.New()
|
||||
si.Placeholder = "Search tasks..."
|
||||
|
||||
delegate := list.NewDefaultDelegate()
|
||||
delegate.Styles.SelectedTitle = delegate.Styles.SelectedTitle.Foreground(lipgloss.Color("62"))
|
||||
delegate.Styles.SelectedDesc = delegate.Styles.SelectedDesc.Foreground(lipgloss.Color("241"))
|
||||
|
||||
taskList := list.New(nil, delegate, 0, 0)
|
||||
taskList.Title = "Todo Tasks"
|
||||
taskList.SetShowTitle(true)
|
||||
taskList.SetFilteringEnabled(true)
|
||||
|
||||
s := spinner.New()
|
||||
s.Spinner = spinner.Dot
|
||||
s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("205"))
|
||||
|
||||
m := Model{
|
||||
list: taskList,
|
||||
titleInput: ti,
|
||||
bodyInput: bi,
|
||||
searchInput: si,
|
||||
spinner: s,
|
||||
db: db,
|
||||
aiService: aiService,
|
||||
state: StateNormal,
|
||||
itemsPerPage: 10,
|
||||
analysisResults: make(map[int64]*database.Analysis),
|
||||
isLoading: true,
|
||||
analyzingTasks: make(map[int]bool),
|
||||
showCompleted: true, // Show all tasks by default
|
||||
analysisTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m Model) Init() tea.Cmd {
|
||||
m.isLoading = true // Set initial loading state
|
||||
return tea.Batch(m.loadTasks, m.spinner.Tick)
|
||||
}
|
||||
|
||||
func (m Model) loadTasks() tea.Msg {
|
||||
dbTasks, err := m.db.GetTasks(context.Background(), m.showCompleted)
|
||||
if err != nil {
|
||||
return errMsg{err}
|
||||
}
|
||||
|
||||
// Convert database.Task to models.Task
|
||||
var items []list.Item
|
||||
for _, t := range dbTasks {
|
||||
items = append(items, taskItem{
|
||||
task: models.Task{
|
||||
ID: t.ID,
|
||||
Title: t.Title,
|
||||
Completed: t.Completed,
|
||||
CreatedAt: t.CreatedAt,
|
||||
},
|
||||
analyzing: m.analyzingTasks[int(t.ID)],
|
||||
})
|
||||
}
|
||||
|
||||
return tasksLoadedMsg{items: items}
|
||||
}
|
||||
|
||||
func (m Model) updateTaskList(tasks []models.Task) {
|
||||
items := make([]list.Item, len(tasks))
|
||||
for i, task := range tasks {
|
||||
items[i] = taskItem{
|
||||
task: task,
|
||||
analyzing: m.analyzingTasks[int(task.ID)],
|
||||
}
|
||||
}
|
||||
m.list.SetItems(items)
|
||||
}
|
||||
|
||||
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case nil:
|
||||
// Initial update, load tasks
|
||||
m.isLoading = true
|
||||
return m, m.loadTasks
|
||||
|
||||
case spinner.TickMsg:
|
||||
var cmd tea.Cmd
|
||||
m.spinner, cmd = m.spinner.Update(msg)
|
||||
return m, cmd
|
||||
|
||||
case tasksLoadedMsg:
|
||||
m.isLoading = false
|
||||
m.list.SetItems(msg.items)
|
||||
if len(msg.items) > 0 {
|
||||
m.list.Select(0)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "q", "ctrl+c":
|
||||
return m, tea.Quit
|
||||
case "n":
|
||||
if m.state == StateNormal {
|
||||
m.state = StateCreating
|
||||
m.titleInput.Focus()
|
||||
return m, textinput.Blink
|
||||
}
|
||||
case "/":
|
||||
if m.state == StateNormal {
|
||||
m.state = StateSearching
|
||||
m.searchInput.Focus()
|
||||
return m, textinput.Blink
|
||||
}
|
||||
case "e":
|
||||
if m.state == StateNormal {
|
||||
if i, ok := m.list.SelectedItem().(taskItem); ok {
|
||||
m.state = StateEditing
|
||||
m.selected = &database.Task{ID: i.task.ID, Title: i.task.Title, Completed: i.task.Completed}
|
||||
m.titleInput.SetValue(i.task.Title)
|
||||
m.titleInput.Focus()
|
||||
return m, textinput.Blink
|
||||
}
|
||||
}
|
||||
case "tab":
|
||||
if m.state == StateCreating || m.state == StateEditing {
|
||||
if m.titleInput.Focused() {
|
||||
m.titleInput.Blur()
|
||||
m.bodyInput.Focus()
|
||||
return m, textinput.Blink
|
||||
} else if m.bodyInput.Focused() {
|
||||
m.bodyInput.Blur()
|
||||
m.titleInput.Focus()
|
||||
return m, textinput.Blink
|
||||
}
|
||||
}
|
||||
m.showCompleted = !m.showCompleted
|
||||
m.filter = "" // Clear filter when toggling completed tasks
|
||||
m.isLoading = true
|
||||
return m, m.loadTasks
|
||||
case "shift+tab":
|
||||
if m.state == StateCreating || m.state == StateEditing {
|
||||
if m.titleInput.Focused() {
|
||||
m.titleInput.Blur()
|
||||
m.bodyInput.Focus()
|
||||
return m, textinput.Blink
|
||||
} else if m.bodyInput.Focused() {
|
||||
m.bodyInput.Blur()
|
||||
m.titleInput.Focus()
|
||||
return m, textinput.Blink
|
||||
}
|
||||
}
|
||||
case "c":
|
||||
if m.state == StateNormal {
|
||||
if i, ok := m.list.SelectedItem().(taskItem); ok {
|
||||
err := m.db.CompleteTask(context.Background(), int(i.task.ID))
|
||||
if err != nil {
|
||||
m.errorMsg = fmt.Sprintf("Failed to complete task: %v", err)
|
||||
}
|
||||
m.showCompleted = true // Show completed tasks after completing one
|
||||
m.filter = "" // Clear filter when showing completed tasks
|
||||
m.isLoading = true
|
||||
return m, m.loadTasks
|
||||
}
|
||||
}
|
||||
case "a":
|
||||
if m.state == StateNormal || m.state == StateViewing {
|
||||
if i, ok := m.list.SelectedItem().(taskItem); ok {
|
||||
m.state = StateAnalyzing
|
||||
m.selected = &database.Task{ID: i.task.ID, Title: i.task.Title, Completed: i.task.Completed}
|
||||
m.analyzingTasks[int(i.task.ID)] = true
|
||||
return m, m.analyzeTask(i.task)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
case "enter":
|
||||
switch m.state {
|
||||
case StateCreating:
|
||||
if m.titleInput.Focused() && m.titleInput.Value() != "" {
|
||||
m.titleInput.Blur()
|
||||
m.bodyInput.Focus()
|
||||
return m, textinput.Blink
|
||||
} else if m.bodyInput.Focused() {
|
||||
if m.titleInput.Value() != "" {
|
||||
_, err := m.db.AddTask(context.Background(), m.titleInput.Value(), m.bodyInput.Value())
|
||||
if err != nil {
|
||||
m.errorMsg = fmt.Sprintf("Failed to add task: %v", err)
|
||||
}
|
||||
m.titleInput.Reset()
|
||||
m.bodyInput.Reset()
|
||||
m.state = StateNormal
|
||||
m.isLoading = true
|
||||
return m, m.loadTasks
|
||||
}
|
||||
}
|
||||
case StateEditing:
|
||||
if m.titleInput.Focused() && m.titleInput.Value() != "" {
|
||||
m.titleInput.Blur()
|
||||
m.bodyInput.Focus()
|
||||
return m, textinput.Blink
|
||||
} else if m.bodyInput.Focused() && m.selected != nil {
|
||||
err := m.db.UpdateTask(context.Background(), m.selected.ID, m.titleInput.Value(), m.bodyInput.Value())
|
||||
if err != nil {
|
||||
m.errorMsg = fmt.Sprintf("Failed to update task: %v", err)
|
||||
}
|
||||
m.titleInput.Reset()
|
||||
m.bodyInput.Reset()
|
||||
m.selected = nil
|
||||
m.state = StateNormal
|
||||
m.isLoading = true
|
||||
return m, m.loadTasks
|
||||
}
|
||||
case StateSearching:
|
||||
m.filter = m.searchInput.Value()
|
||||
m.searchInput.Reset()
|
||||
m.state = StateNormal
|
||||
m.isLoading = true
|
||||
return m, m.loadTasks
|
||||
case StateNormal:
|
||||
if i, ok := m.list.SelectedItem().(taskItem); ok {
|
||||
m.selected = &database.Task{ID: i.task.ID, Title: i.task.Title, Completed: i.task.Completed}
|
||||
m.state = StateViewing
|
||||
|
||||
// Load existing analysis when viewing task
|
||||
analysis, err := m.db.GetTaskAnalysis(context.Background(), int(i.task.ID))
|
||||
if err == nil && analysis != nil {
|
||||
if m.analysisResults == nil {
|
||||
m.analysisResults = make(map[int64]*database.Analysis)
|
||||
}
|
||||
m.analysisResults[i.task.ID] = analysis
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// If no analysis exists, run the analysis
|
||||
return m, m.analyzeTask(i.task)
|
||||
}
|
||||
}
|
||||
case "esc":
|
||||
if m.state != StateNormal {
|
||||
m.state = StateNormal
|
||||
m.errorMsg = ""
|
||||
m.selected = nil
|
||||
m.titleInput.Reset()
|
||||
m.bodyInput.Reset()
|
||||
if m.state == StateSearching {
|
||||
m.filter = ""
|
||||
}
|
||||
m.isLoading = true
|
||||
return m, m.loadTasks
|
||||
}
|
||||
}
|
||||
|
||||
case tea.WindowSizeMsg:
|
||||
m.list.SetWidth(msg.Width)
|
||||
m.list.SetHeight(msg.Height)
|
||||
return m, nil
|
||||
|
||||
case analysisMsg:
|
||||
if m.state == StateAnalyzing {
|
||||
// Get the currently selected task
|
||||
if i, ok := m.list.SelectedItem().(taskItem); ok {
|
||||
// Create clean, properly formatted JSON
|
||||
analysis := &database.Analysis{
|
||||
TaskID: i.task.ID,
|
||||
Result: fmt.Sprintf(`{"priority":%d,"category":"%s","estimate":%d,"related":"%s","summary":"%s"}`,
|
||||
msg.result.Priority,
|
||||
msg.result.Category,
|
||||
msg.result.TimeEstimate,
|
||||
msg.result.RelatedTasks,
|
||||
msg.result.Summary),
|
||||
AnalyzedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := m.db.StoreAnalysis(context.Background(), analysis); err != nil {
|
||||
m.errorMsg = fmt.Sprintf("Failed to store analysis: %v", err)
|
||||
} else {
|
||||
if m.analysisResults == nil {
|
||||
m.analysisResults = make(map[int64]*database.Analysis)
|
||||
}
|
||||
m.analysisResults[i.task.ID] = analysis
|
||||
}
|
||||
delete(m.analyzingTasks, int(i.task.ID))
|
||||
}
|
||||
m.state = StateViewing
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case errMsg:
|
||||
m.errorMsg = msg.Error()
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Update the appropriate input based on current state
|
||||
var cmd tea.Cmd
|
||||
switch m.state {
|
||||
case StateAnalyzing:
|
||||
var spinnerCmd tea.Cmd
|
||||
m.spinner, spinnerCmd = m.spinner.Update(msg)
|
||||
cmds = append(cmds, spinnerCmd)
|
||||
case StateCreating, StateEditing:
|
||||
m.titleInput, cmd = m.titleInput.Update(msg)
|
||||
cmds = append(cmds, cmd)
|
||||
m.bodyInput, cmd = m.bodyInput.Update(msg)
|
||||
cmds = append(cmds, cmd)
|
||||
case StateSearching:
|
||||
m.searchInput, cmd = m.searchInput.Update(msg)
|
||||
cmds = append(cmds, cmd)
|
||||
case StateNormal:
|
||||
if !m.isLoading {
|
||||
m.list, cmd = m.list.Update(msg)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
var spinnerCmd tea.Cmd
|
||||
m.spinner, spinnerCmd = m.spinner.Update(msg)
|
||||
cmds = append(cmds, spinnerCmd)
|
||||
}
|
||||
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m Model) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
if m.errorMsg != "" {
|
||||
b.WriteString(fmt.Sprintf("Error: %s\n\n", m.errorMsg))
|
||||
}
|
||||
|
||||
switch m.state {
|
||||
case StateCreating:
|
||||
b.WriteString("Add new task:\n\n")
|
||||
b.WriteString("Title:\n")
|
||||
b.WriteString(m.titleInput.View())
|
||||
b.WriteString("\n\nDescription:\n")
|
||||
b.WriteString(m.bodyInput.View())
|
||||
b.WriteString("\n\n(Tab/Shift+Tab to switch fields, Enter to move to description or save, ESC to cancel)")
|
||||
|
||||
case StateEditing:
|
||||
b.WriteString("Edit task:\n\n")
|
||||
b.WriteString("Title:\n")
|
||||
b.WriteString(m.titleInput.View())
|
||||
b.WriteString("\n\nDescription:\n")
|
||||
b.WriteString(m.bodyInput.View())
|
||||
b.WriteString("\n\n(Tab/Shift+Tab to switch fields, Enter to move to description or save, ESC to cancel)")
|
||||
|
||||
case StateSearching:
|
||||
b.WriteString("Search tasks:\n\n")
|
||||
b.WriteString(m.searchInput.View())
|
||||
b.WriteString("\n\n(ESC to cancel)")
|
||||
|
||||
case StateViewing:
|
||||
if m.selected != nil {
|
||||
b.WriteString(m.renderTaskDetail())
|
||||
return b.String()
|
||||
}
|
||||
|
||||
case StateAnalyzing:
|
||||
b.WriteString(fmt.Sprintf("\n\n%s Analyzing task...\n", m.spinner.View()))
|
||||
return taskStyle.Render(b.String())
|
||||
|
||||
default:
|
||||
header := fmt.Sprintf(
|
||||
"%s\n\nPress 'n' for new task, 'e' to edit task, '/' to search, TAB to toggle completed, 'c' to complete task\n",
|
||||
titleStyle.Render("Your Tasks"),
|
||||
)
|
||||
b.WriteString(header)
|
||||
if m.isLoading {
|
||||
b.WriteString(fmt.Sprintf("\n%s Loading tasks...\n", m.spinner.View()))
|
||||
} else {
|
||||
b.WriteString(m.list.View())
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m Model) renderTaskDetail() string {
|
||||
var b strings.Builder
|
||||
|
||||
if m.selected == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
title := m.selected.Title
|
||||
if m.selected.Completed {
|
||||
title = completedStyle.Render(title)
|
||||
}
|
||||
|
||||
b.WriteString(fmt.Sprintf("Title: %s\n\n", title))
|
||||
|
||||
// Add analysis results if available
|
||||
if analysis, ok := m.analysisResults[m.selected.ID]; ok {
|
||||
var result ai.AnalysisResult
|
||||
if err := json.Unmarshal([]byte(analysis.Result), &result); err == nil {
|
||||
b.WriteString("\nAnalysis Results:\n")
|
||||
b.WriteString(fmt.Sprintf("Priority: %d/5\n", result.Priority))
|
||||
b.WriteString(fmt.Sprintf("Category: %s\n", result.Category))
|
||||
b.WriteString(fmt.Sprintf("Time Estimate: %d minutes\n", result.TimeEstimate))
|
||||
if result.RelatedTasks != "" {
|
||||
b.WriteString(fmt.Sprintf("Related Tasks: %s\n", result.RelatedTasks))
|
||||
}
|
||||
if result.Summary != "" && result.Summary != title {
|
||||
b.WriteString(fmt.Sprintf("Summary: %s\n", result.Summary))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\nAnalyzed at: %s", analysis.AnalyzedAt.Format("2006-01-02 15:04:05")))
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("\n\nPress 'a' to analyze task, ESC to return")
|
||||
return taskStyle.Render(b.String())
|
||||
}
|
||||
|
||||
func (m Model) analyzeTask(task models.Task) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Create a timeout channel
|
||||
done := make(chan struct{})
|
||||
var result *ai.AnalysisResult
|
||||
var err error
|
||||
|
||||
// Run analysis in a goroutine
|
||||
go func() {
|
||||
result, err = m.aiService.AnalyzeTask(context.Background(), task.Title)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Wait for either completion or timeout
|
||||
select {
|
||||
case <-done:
|
||||
if err != nil {
|
||||
return errMsg{err}
|
||||
}
|
||||
return analysisMsg{result}
|
||||
case <-time.After(m.analysisTimeout):
|
||||
return analysisMsg{&ai.AnalysisResult{
|
||||
Priority: 3,
|
||||
Category: "task",
|
||||
TimeEstimate: 30,
|
||||
RelatedTasks: "",
|
||||
Summary: "Analysis timed out for: " + task.Title,
|
||||
}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type taskItem struct {
|
||||
task models.Task
|
||||
analyzing bool
|
||||
}
|
||||
|
||||
func (i taskItem) Title() string {
|
||||
status := "[ ]"
|
||||
if i.task.Completed {
|
||||
status = "[x]"
|
||||
}
|
||||
|
||||
aiStatus := ""
|
||||
if i.analyzing {
|
||||
aiStatus = " (analyzing...)"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s%s", status, i.task.Title, aiStatus)
|
||||
}
|
||||
|
||||
func (i taskItem) Description() string { return "" }
|
||||
func (i taskItem) FilterValue() string { return i.task.Title }
|
||||
|
||||
type tasksLoadedMsg struct {
|
||||
items []list.Item
|
||||
}
|
||||
|
||||
type analysisMsg struct {
|
||||
result *ai.AnalysisResult
|
||||
}
|
||||
|
||||
type errMsg struct{ error }
|
||||
342
src/ui/tui_test.go
Normal file
342
src/ui/tui_test.go
Normal file
@ -0,0 +1,342 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"github.com/your-username/todo/src/ai"
|
||||
"github.com/your-username/todo/src/database"
|
||||
)
|
||||
|
||||
// MockDatabase implements Database interface for testing
|
||||
type MockDatabase struct {
|
||||
tasks []database.Task
|
||||
analyses map[int64]*database.Analysis
|
||||
lastID int64
|
||||
}
|
||||
|
||||
func NewMockDB() *MockDatabase {
|
||||
return &MockDatabase{
|
||||
tasks: make([]database.Task, 0),
|
||||
analyses: make(map[int64]*database.Analysis),
|
||||
lastID: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockDatabase) Reset() {
|
||||
m.tasks = make([]database.Task, 0)
|
||||
m.analyses = make(map[int64]*database.Analysis)
|
||||
m.lastID = 0
|
||||
}
|
||||
|
||||
func (m *MockDatabase) AddTask(_ context.Context, title string, body string) (int64, error) {
|
||||
m.lastID++
|
||||
task := database.Task{
|
||||
ID: m.lastID,
|
||||
Title: title,
|
||||
Body: body,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
m.tasks = append(m.tasks, task)
|
||||
return task.ID, nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) GetTasks(_ context.Context, includeCompleted bool) ([]database.Task, error) {
|
||||
var result []database.Task
|
||||
for _, t := range m.tasks {
|
||||
if includeCompleted || !t.Completed {
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) CompleteTask(_ context.Context, id int) error {
|
||||
for i := range m.tasks {
|
||||
if m.tasks[i].ID == int64(id) {
|
||||
m.tasks[i].Completed = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (m *MockDatabase) UpdateTask(_ context.Context, id int64, title string, body string) error {
|
||||
for i := range m.tasks {
|
||||
if m.tasks[i].ID == id {
|
||||
m.tasks[i].Title = title
|
||||
m.tasks[i].Body = body
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (m *MockDatabase) GetTaskAnalysis(_ context.Context, taskID int) (*database.Analysis, error) {
|
||||
return m.analyses[int64(taskID)], nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) StoreAnalysis(_ context.Context, analysis *database.Analysis) error {
|
||||
m.analyses[analysis.TaskID] = analysis
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockAIService for testing
|
||||
type MockAIService struct {
|
||||
result *ai.AnalysisResult
|
||||
}
|
||||
|
||||
func NewMockAIService() *ai.AIService {
|
||||
return ai.NewAIService([]ai.ModelConfig{
|
||||
{
|
||||
Name: "mock",
|
||||
Provider: "mock",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
type MockAnalyzer struct {
|
||||
result *ai.AnalysisResult
|
||||
}
|
||||
|
||||
func (m *MockAnalyzer) AnalyzeTask(_ context.Context, _ string) (*ai.AnalysisResult, error) {
|
||||
return m.result, nil
|
||||
}
|
||||
|
||||
func TestTUI(t *testing.T) {
|
||||
mockDB := NewMockDB()
|
||||
mockAI := NewMockAIService()
|
||||
|
||||
t.Run("InitialState", func(t *testing.T) {
|
||||
mockDB.Reset()
|
||||
model := New(mockDB, mockAI)
|
||||
if model.state != StateNormal {
|
||||
t.Errorf("Expected initial state to be StateNormal")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AddTask", func(t *testing.T) {
|
||||
mockDB.Reset()
|
||||
model := New(mockDB, mockAI)
|
||||
|
||||
// Simulate pressing 'n' to enter creation mode
|
||||
updatedModel, cmd := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("n")})
|
||||
model = updatedModel.(Model)
|
||||
|
||||
if model.state != StateCreating {
|
||||
t.Error("Expected state to be StateCreating after pressing 'n'")
|
||||
}
|
||||
|
||||
// Simulate typing task title
|
||||
model.titleInput.SetValue("Test task")
|
||||
|
||||
// Simulate pressing enter to create task
|
||||
updatedModel, cmd = model.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
model = updatedModel.(Model)
|
||||
|
||||
// Execute any pending commands
|
||||
if cmd != nil {
|
||||
msg := cmd()
|
||||
updatedModel, _ = model.Update(msg)
|
||||
model = updatedModel.(Model)
|
||||
}
|
||||
|
||||
// Verify task was added
|
||||
tasks, _ := mockDB.GetTasks(context.Background(), false)
|
||||
if len(tasks) != 1 {
|
||||
t.Error("Expected one task to be added")
|
||||
}
|
||||
if len(tasks) > 0 && tasks[0].Title != "Test task" {
|
||||
t.Errorf("Expected task title 'Test task', got '%s'", tasks[0].Title)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CompleteTask", func(t *testing.T) {
|
||||
mockDB.Reset()
|
||||
model := New(mockDB, mockAI)
|
||||
|
||||
// Add a task first
|
||||
id, _ := mockDB.AddTask(context.Background(), "Task to complete", "")
|
||||
|
||||
// Reload task list and execute the command
|
||||
updatedModel, cmd := model.Update(nil)
|
||||
model = updatedModel.(Model)
|
||||
if cmd != nil {
|
||||
msg := cmd()
|
||||
updatedModel, cmd = model.Update(msg)
|
||||
model = updatedModel.(Model)
|
||||
}
|
||||
|
||||
// Verify task is in list
|
||||
items := model.list.Items()
|
||||
if len(items) != 1 {
|
||||
t.Errorf("Expected 1 task in list before completion, got %d", len(items))
|
||||
}
|
||||
|
||||
// Select first task
|
||||
model.list.Select(0)
|
||||
|
||||
// Verify correct task is selected
|
||||
if i, ok := model.list.SelectedItem().(taskItem); ok {
|
||||
if i.Task.ID != id {
|
||||
t.Errorf("Expected selected task ID %d, got %d", id, i.Task.ID)
|
||||
}
|
||||
} else {
|
||||
t.Error("Could not get selected task")
|
||||
}
|
||||
|
||||
// Simulate pressing 'c' to complete
|
||||
updatedModel, cmd = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("c")})
|
||||
model = updatedModel.(Model)
|
||||
|
||||
// Execute any pending commands
|
||||
if cmd != nil {
|
||||
msg := cmd()
|
||||
updatedModel, _ = model.Update(msg)
|
||||
model = updatedModel.(Model)
|
||||
}
|
||||
|
||||
// Verify task was completed
|
||||
tasks, _ := mockDB.GetTasks(context.Background(), true)
|
||||
found := false
|
||||
for _, task := range tasks {
|
||||
if task.ID == id && task.Completed {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected task to be marked as completed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SearchTasks", func(t *testing.T) {
|
||||
mockDB.Reset()
|
||||
model := New(mockDB, mockAI)
|
||||
|
||||
// Add test tasks
|
||||
mockDB.AddTask(context.Background(), "Test task 1", "")
|
||||
mockDB.AddTask(context.Background(), "Test task 2", "")
|
||||
mockDB.AddTask(context.Background(), "Different task", "")
|
||||
|
||||
// Reload task list and execute the command
|
||||
updatedModel, cmd := model.Update(nil)
|
||||
model = updatedModel.(Model)
|
||||
if cmd != nil {
|
||||
msg := cmd()
|
||||
updatedModel, cmd = model.Update(msg)
|
||||
model = updatedModel.(Model)
|
||||
}
|
||||
|
||||
// Verify tasks are in list
|
||||
items := model.list.Items()
|
||||
if len(items) != 3 {
|
||||
t.Errorf("Expected 3 tasks in list before search, got %d", len(items))
|
||||
}
|
||||
|
||||
// Simulate pressing '/' to enter search mode
|
||||
updatedModel, cmd = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("/")})
|
||||
model = updatedModel.(Model)
|
||||
|
||||
if model.state != StateSearching {
|
||||
t.Error("Expected state to be StateSearching")
|
||||
}
|
||||
|
||||
// Set search filter
|
||||
model.searchInput.SetValue("Test")
|
||||
updatedModel, cmd = model.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
model = updatedModel.(Model)
|
||||
|
||||
// Execute any pending commands
|
||||
if cmd != nil {
|
||||
msg := cmd()
|
||||
updatedModel, _ = model.Update(msg)
|
||||
model = updatedModel.(Model)
|
||||
}
|
||||
|
||||
// Verify filtered tasks
|
||||
items = model.list.Items()
|
||||
if len(items) != 2 {
|
||||
t.Errorf("Expected 2 filtered tasks, got %d", len(items))
|
||||
for _, item := range items {
|
||||
t.Logf("Task: %s", item.FilterValue())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EditTask", func(t *testing.T) {
|
||||
mockDB.Reset()
|
||||
model := New(mockDB, mockAI)
|
||||
|
||||
// Add a task first
|
||||
id, _ := mockDB.AddTask(context.Background(), "Task to edit", "Original description")
|
||||
|
||||
// Reload task list and execute the command
|
||||
updatedModel, cmd := model.Update(nil)
|
||||
model = updatedModel.(Model)
|
||||
if cmd != nil {
|
||||
msg := cmd()
|
||||
updatedModel, cmd = model.Update(msg)
|
||||
model = updatedModel.(Model)
|
||||
}
|
||||
|
||||
// Select first task
|
||||
model.list.Select(0)
|
||||
|
||||
// Verify correct task is selected
|
||||
if i, ok := model.list.SelectedItem().(taskItem); ok {
|
||||
if i.Task.ID != id {
|
||||
t.Errorf("Expected selected task ID %d, got %d", id, i.Task.ID)
|
||||
}
|
||||
} else {
|
||||
t.Error("Could not get selected task")
|
||||
}
|
||||
|
||||
// Simulate pressing 'e' to edit
|
||||
updatedModel, cmd = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("e")})
|
||||
model = updatedModel.(Model)
|
||||
|
||||
if model.state != StateEditing {
|
||||
t.Error("Expected state to be StateEditing")
|
||||
}
|
||||
|
||||
// Set new values
|
||||
model.titleInput.SetValue("Updated task")
|
||||
model.bodyInput.SetValue("Updated description")
|
||||
|
||||
// Simulate pressing enter to move to body input
|
||||
updatedModel, cmd = model.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
model = updatedModel.(Model)
|
||||
|
||||
// Simulate pressing enter to save
|
||||
updatedModel, cmd = model.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
model = updatedModel.(Model)
|
||||
|
||||
// Execute any pending commands
|
||||
if cmd != nil {
|
||||
msg := cmd()
|
||||
updatedModel, _ = model.Update(msg)
|
||||
model = updatedModel.(Model)
|
||||
}
|
||||
|
||||
// Verify task was updated
|
||||
tasks, _ := mockDB.GetTasks(context.Background(), false)
|
||||
found := false
|
||||
for _, task := range tasks {
|
||||
if task.ID == id {
|
||||
if task.Title != "Updated task" || task.Body != "Updated description" {
|
||||
t.Errorf("Expected task to be updated with new title and description")
|
||||
}
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Could not find updated task")
|
||||
}
|
||||
})
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user