From 41c2266ed0de2d44ee7e88cc8bbeda2c546c34e9 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 18:39:40 +0530 Subject: [PATCH 01/20] refactor(cmd): split chat.go into focused files Extract the Bubble Tea event loop (Update, applyPromptArrowKey) into chat_update.go and the tool-registry construction (essential/optional tools, defaultRegistry) into chat_tools.go. chat.go now holds only model construction and lifecycle, dropping from 1501 to 582 LOC. Pure code movement; no behavior or public API changes. --- cmd/chat.go | 923 +-------------------------------------------- cmd/chat_tools.go | 152 ++++++++ cmd/chat_update.go | 804 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 958 insertions(+), 921 deletions(-) create mode 100644 cmd/chat_tools.go create mode 100644 cmd/chat_update.go diff --git a/cmd/chat.go b/cmd/chat.go index 3d996b31..c6c8bb22 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -39,150 +39,14 @@ import ( "github.com/GrayCodeAI/hawk/internal/session" "github.com/GrayCodeAI/hawk/internal/startup" "github.com/GrayCodeAI/hawk/internal/system/staleness" - "github.com/GrayCodeAI/hawk/internal/tool" "github.com/GrayCodeAI/hawk/internal/ui/icons" ) // Types, styles, and model struct are in chat_model.go // Welcome message and config summary helpers are in chat_welcome.go // Slash command handling and helpers are in chat_commands.go - -func essentialTools() []tool.Tool { - // Core tools needed for basic agent operation - always loaded at startup - return []tool.Tool{ - tool.BashTool{}, - tool.FileReadTool{}, - tool.FileWriteTool{}, - tool.FileEditTool{}, - tool.StructuredEditTool{}, - tool.LSTool{}, - tool.GlobTool{}, - tool.GrepTool{}, - tool.WebFetchTool{}, - tool.WebSearchTool{}, - tool.ToolSearchTool{}, - tool.SkillTool{}, - tool.AgentTool{}, - tool.AskUserQuestionTool{}, - tool.TodoWriteTool{}, - tool.TaskOutputTool{}, - tool.TaskStopTool{}, - tool.LSPTool{}, - tool.MultiEditTool{}, - } -} - -func optionalTools() []tool.Tool { - // Specialized tools that can be lazy-loaded on demand - return []tool.Tool{ - tool.EnterPlanModeTool{}, - tool.ExitPlanModeTool{}, - tool.NotebookEditTool{}, - tool.EnterWorktreeTool{}, - tool.ExitWorktreeTool{}, - tool.ListMcpResourcesTool{}, - tool.ReadMcpResourceTool{}, - tool.ConfigTool{}, - tool.BriefTool{}, - tool.TaskCreateTool{}, - tool.TaskGetTool{}, - tool.TaskListTool{}, - tool.TaskUpdateTool{}, - tool.SleepTool{}, - tool.CronCreateTool{}, - tool.CronDeleteTool{}, - tool.CronListTool{}, - tool.VerifyPlanExecutionTool{}, - tool.WorkflowTool{}, - tool.McpAuthTool{}, - tool.DiagnosticsTool{}, - tool.CodeSearchTool{}, - tool.CoreMemoryAppendTool{}, - tool.CoreMemoryReplaceTool{}, - tool.CoreMemoryRethinkTool{}, - tool.DownloadTool{}, - tool.AgenticFetchTool{}, - tool.ImpactTool{}, - tool.GitHistoryTool{}, - tool.CodeGraphTool{}, - tool.NilAwayTool{}, - tool.ReviveTool{}, - tool.MCPLanguageServerTool{}, - tool.SQLTool{}, - } -} - -func defaultRegistry(settings hawkconfig.Settings) (*tool.Registry, error) { - // Load essential tools first for fast startup - tools := essentialTools() - if tool.IsPowerShellAvailable() { - tools = append(tools, tool.PowerShellTool{}) - } - // Detect project-level MCP servers (supply chain attack vector). - // Project .hawk/settings.json can be committed to a repo and define - // arbitrary commands that execute on clone. Gate behind --allow-project-mcp. - projectMCPServers := hawkconfig.ProjectMCPServers() - projectMCPNames := make(map[string]bool, len(projectMCPServers)) - for _, cfg := range projectMCPServers { - if cfg.Name != "" { - projectMCPNames[cfg.Name] = true - } - } - for _, cfg := range settings.MCPServers { - if cfg.Name == "" || cfg.Command == "" { - continue - } - if projectMCPNames[cfg.Name] && !allowProjectMCP { - fmt.Fprintf(os.Stderr, "hawk: skipping project-level MCP server %q (defined in .hawk/settings.json); use --allow-project-mcp to enable\n", cfg.Name) - continue - } - mcpTools, err := tool.LoadMCPTools(context.Background(), cfg.Name, cfg.Command, cfg.Args...) - if err != nil { - continue - } - tools = append(tools, mcpTools...) - } - // Load MCP server tools - for _, cmd := range mcpServers { - parts := strings.Fields(cmd) - if len(parts) == 0 { - continue - } - name := parts[0] - mcpTools, err := tool.LoadMCPTools(context.Background(), name, parts[0], parts[1:]...) - if err != nil { - // MCP server failed to connect — skip silently, will show in /doctor - continue - } - tools = append(tools, mcpTools...) - } - - filtered, err := filterAvailableTools( - tools, - toolsFlagSet, - parseToolListFromCLI(toolsFlag), - parseToolListFromCLI(disallowedToolsFlag), - ) - if err != nil { - return nil, err - } - registry := tool.NewRegistry(filtered...) - - // Lazy-load optional tools in background - go func() { - for _, t := range optionalTools() { - _ = registry.Register(t) - } - }() - - return registry, nil -} - -func allTools() []tool.Tool { - t := essentialTools() - t = append(t, optionalTools()...) - return t -} +// Tool-registry construction (essential/optional tools) is in chat_tools.go +// The Bubble Tea event loop (Update, applyPromptArrowKey) is in chat_update.go func genID() string { b := make([]byte, 8) @@ -567,789 +431,6 @@ func (m chatModel) Init() tea.Cmd { return tea.Batch(cmds...) } -// applyPromptArrowKey handles Up/Down in the prompt: slash menu navigation or input history. -// Returns true when the key was consumed so callers skip textarea/updateInput handling. -func (m *chatModel) applyPromptArrowKey(msg tea.KeyMsg) bool { - if m.uiFocus != focusPrompt || m.configOpen { - return false - } - switch msg.Type { - case tea.KeyUp, tea.KeyDown: - default: - return false - } - sugs := m.slashSuggestionsFor(m.input.Value()) - if len(sugs) > 0 { - switch msg.Type { - case tea.KeyUp: - if m.slashSel <= 0 { - m.slashSel = len(sugs) - 1 - } else { - m.slashSel-- - } - case tea.KeyDown: - m.slashSel = (m.slashSel + 1) % len(sugs) - } - return true - } - switch msg.Type { - case tea.KeyUp: - if len(m.history) > 0 { - if m.historyIdx == len(m.history) { - m.historyDraft = m.input.Value() - } - if m.historyIdx > 0 { - m.historyIdx-- - m.input.SetValue(m.history[m.historyIdx]) - m.input.CursorEnd() - } - } - return true - case tea.KeyDown: - if m.historyIdx < len(m.history)-1 { - m.historyIdx++ - m.input.SetValue(m.history[m.historyIdx]) - m.input.CursorEnd() - } else if m.historyIdx == len(m.history)-1 { - m.historyIdx = len(m.history) - m.input.SetValue(m.historyDraft) - m.input.CursorEnd() - } - return true - } - return false -} - -func (m chatModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - - switch msg := msg.(type) { - case tea.MouseMsg: - if m.mouseEnabled() { - if tea.MouseEvent(msg).IsWheel() { - m.trackMousePosition(msg) - cmds = append(cmds, m.applyMouseScroll(msg)) - m.sanitizeInputIfNeeded() - m = m.syncViewportMouseWheel().withSyncedLayout() - if m.syncInputLayout() { - m.updateViewportContent() - } - if focus := m.ensurePromptInputFocus(); focus != nil { - cmds = append(cmds, focus) - } - } else { - // Motion events (?1003): track pointer only — avoid layout/sanitize/focus per move. - m.trackMousePosition(msg) - } - } - return m, tea.Batch(cmds...) - - case autoOpenConfigMsg: - if !m.openConfigOnStart || m.configOpen { - return m, nil - } - m.openConfigOnStart = false - return m.openConfigPanel() - case tea.KeyMsg: - // Ctrl+\ enters native terminal selection mode. Available in every UI - // state (welcome gate, permissions, prompt, scrollback) so users always - // have a way to copy text out of the chat — the alt-screen + - // mouse-tracking combination otherwise breaks native text selection. - if msg.Type == tea.KeyCtrlBackslash { - return m, enterSelectionMode(m.ref, m.copyableTranscript(), m.mouseEnabled()) - } - if isCopyToClipboardKey(msg) { - return m.handleCopyShortcut() - } - if isMouseSequenceLeak(msg) { - if handled, cmd := m.tryScrollFromMouseLeak(msg); handled { - m.sanitizeInputIfNeeded() - if focus := m.ensurePromptInputFocus(); focus != nil { - return m, tea.Batch(cmd, focus) - } - return m, cmd - } - m.sanitizeInputIfNeeded() - if focus := m.ensurePromptInputFocus(); focus != nil { - return m, focus - } - return m, nil - } - if next, cmd, handled := m.handleWelcomeGateKey(msg); handled { - return next, cmd - } - - // Command palette (Ctrl+K) — intercept all input when open - if m.commandPalette != nil && m.commandPalette.IsOpen() { - action, handled := m.commandPalette.Update(msg) - if handled { - if action != "" { - // Execute the selected command - m.commandPalette.Close() - result, _ := m.handleCommand(action) - if cm, ok := result.(chatModel); ok { - m = cm - } - m.viewDirty = true - m.updateViewportContent() - } - return m, nil - } - } - - if m.manualCompacting { - if isCompactCancelKey(msg) { - return m.cancelManualCompact("Compaction cancelled.") - } - if msg.Type == tea.KeyEnter { - return m, nil - } - // Allow typing in the input while compaction runs (Esc cancels). - } - - if m.inScrollbackFocus() { - switch msg.Type { - case tea.KeyTab: - return m.cycleUIFocus() - case tea.KeyEsc: - m.uiFocus = focusPrompt - m.viewDirty = true - return m, m.input.Focus() - } - if scrolled, cmd := m.applyViewportScroll(msg); scrolled { - return m, cmd - } - if m.routeKeyToViewport(msg) { - var cmd tea.Cmd - m.viewport, cmd = m.viewport.Update(msg) - if m.viewport.AtBottom() { - m.autoScroll = true - } else { - m.autoScroll = false - } - return m, cmd - } - return m, nil - } - - if scrolled, cmd := m.applyViewportScroll(msg); scrolled { - return m, cmd - } - - // Permission prompt active — handle y/n - if m.permReq != nil { - switch msg.String() { - case "y", "Y": - m.permReq.Response <- true - m.messages = append(m.messages, displayMsg{role: "system", content: icons.CheckBold() + " Allowed"}) - m.permReq = nil - case "n", "N": - m.permReq.Response <- false - m.messages = append(m.messages, displayMsg{role: "system", content: icons.CloseThick() + " Denied"}) - m.permReq = nil - case "a", "A": - m.permReq.Response <- true - m.session.Perm.Memory.AlwaysAllow(m.permReq.ToolName) - m.messages = append(m.messages, displayMsg{role: "system", content: icons.CheckBold() + " Always allowed: " + m.permReq.ToolName}) - m.permReq = nil - } - m.viewDirty = true - m.updateViewportContent() - return m, nil - } - // AskUser prompt active — Enter submits answer - if m.askReq != nil { - if msg.Type == tea.KeyEnter { - answer := strings.TrimSpace(m.input.Value()) - m.input.Reset() - m.messages = append(m.messages, displayMsg{role: "user", content: answer}) - m.askReq.response <- answer - m.askReq = nil - m.viewDirty = true - m.updateViewportContent() - return m, nil - } - return m, m.updateInput(msg) - } - if m.waiting { - if msg.Type == tea.KeyCtrlC { - // First Ctrl+C cancels stream, second quits - if m.cancel != nil { - m.cancel() - m.cancel = nil - m.streamCancelled = true - m.messages = append(m.messages, displayMsg{role: "system", content: icons.Stop() + " Cancelled."}) - if m.partial.Len() > 0 { - m.messages = append(m.messages, displayMsg{role: "assistant", content: m.partial.String()}) - m.partial.Reset() - } - m.waiting = false - m.input.Focus() - m.viewDirty = true - m.updateViewportContent() - return m, nil - } - m.saveSession() - if m.watcherStop != nil { - m.watcherStop() - } - m.quitting = true - return m, tea.Quit - } - if msg.Type == tea.KeyEsc { - if m.cancel != nil { - m.cancel() - m.cancel = nil - m.streamCancelled = true - m.messages = append(m.messages, displayMsg{role: "system", content: icons.Stop() + " Cancelled."}) - if m.partial.Len() > 0 { - m.messages = append(m.messages, displayMsg{role: "assistant", content: m.partial.String()}) - m.partial.Reset() - } - m.waiting = false - m.input.Focus() - } - m.viewDirty = true - m.updateViewportContent() - return m, nil - } - // Queue message on Enter while agent is working - if msg.Type == tea.KeyEnter { - text := strings.TrimSpace(m.input.Value()) - if text != "" { - m.messageQueue = append(m.messageQueue, text) - m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("%s Queued: %s", icons.Mail(), text)}) - m.input.Reset() - m.viewDirty = true - m.updateViewportContent() - } - return m, nil - } - if m.applyPromptArrowKey(msg) { - return m, nil - } - return m, m.updateInput(msg) - } - if m.configOpen { - switch msg.Type { - case tea.KeyCtrlC: - if time.Since(m.lastCtrlC) < 1*time.Second { - m.saveSession() - if m.watcherStop != nil { - m.watcherStop() - } - m.quitting = true - return m, tea.Quit - } - m.lastCtrlC = time.Now() - m.messages = append(m.messages, displayMsg{role: "system", content: quitAgainMsg}) - m.viewDirty = true - m.updateViewportContent() - return m, nil - default: - next, cmd := m.handleConfigKey(msg) - next.viewDirty = true - next.updateViewportContent() - return next, cmd - } - } - switch msg.Type { - case tea.KeyCtrlA: - // Toggle the Agent Status HUD overlay. - m.hudOpen = !m.hudOpen - if m.hudOpen { - m.hudData = m.collectHUDData() - } - m.viewDirty = true - m.updateViewportContent() - return m, nil - case tea.KeyCtrlK: - // Open command palette - if m.commandPalette == nil { - m.commandPalette = NewCommandPalette(m.width) - } - m.commandPalette.Open() - m.viewDirty = true - m.updateViewportContent() - return m, nil - case tea.KeyCtrlN: - models := configModelChoices(m.configModelOptions, false) - if len(models) > 1 { - current := m.session.Model() - idx := 0 - for i, md := range models { - if md == current { - idx = (i + 1) % len(models) - break - } - } - m.session.SetModel(models[idx]) - m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("Model → %s", models[idx])}) - } - m.viewDirty = true - m.updateViewportContent() - return m, nil - case tea.KeyCtrlL: - if m.containerEnabled && !m.containerReady { - m.messages = append(m.messages, displayMsg{role: "system", content: "Waiting for sandbox — tiers unlock when container is ready."}) - m.viewDirty = true - m.updateViewportContent() - return m, nil - } - next := nextAutonomyTier(m.session.PermSvc().Autonomy()) - if m.session.PermSvc().Autonomy() == 0 || autonomyTierIndex(m.session.PermSvc().Autonomy()) < 0 { - next = DefaultContainerAutonomy - } - m.session.PermSvc().SetAutonomy(next) - m.invalidateConnStatus() - m.messages = append(m.messages, displayMsg{role: "system", content: formatAutonomyTierMessage(next)}) - m.viewDirty = true - m.updateViewportContent() - return m, nil - case tea.KeyCtrlC: - if time.Since(m.lastCtrlC) < 1*time.Second { - m.saveSession() - if m.watcherStop != nil { - m.watcherStop() - } - m.quitting = true - return m, tea.Quit - } - m.lastCtrlC = time.Now() - m.messages = append(m.messages, displayMsg{role: "system", content: quitAgainMsg}) - m.viewDirty = true - m.updateViewportContent() - return m, nil - case tea.KeyTab: - // Accept ghost text suggestion if active and input is empty - if m.ghostText.Active() && strings.TrimSpace(m.input.Value()) == "" { - accepted := m.ghostText.Accept() - m.input.SetValue(accepted) - m.input.CursorEnd() - return m, nil - } - sugs := m.slashSuggestionsFor(m.input.Value()) - if len(sugs) > 0 { - if m.slashSel < 0 || m.slashSel >= len(sugs) { - m.slashSel = 0 - } - m.input.SetValue(applySlashSuggestion(sugs[m.slashSel])) - m.input.CursorEnd() - return m, nil - } - return m.cycleUIFocus() - case tea.KeyUp, tea.KeyDown: - if m.applyPromptArrowKey(msg) { - return m, nil - } - case tea.KeyEsc: - if len(m.slashSuggestionsFor(m.input.Value())) > 0 { - m.slashSel = 0 - return m, nil - } - case tea.KeyEnter: - return m.submitUserMessage() - } - - case modelsFetchedMsg: - m.configSaving = false - if msg.err != nil { - if m.configOpen { - m.configNotice = sanitizeConfigNotice(hawkconfig.FormatConfigProviderError(msg.provider, msg.err)) - m.viewDirty = true - m.updateViewportContent() - } - return m, nil - } - if len(msg.options) > 0 { - m.configModelOptions = msg.options - if msg.provider != "" { - modelCacheMu.Lock() - modelCache[msg.provider] = msg.options - modelCacheMu.Unlock() - } - if m.configOpen && strings.Contains(m.configNotice, "Loading") { - m.configNotice = "" - } - } else if m.configOpen { - m.configNotice = hawkconfig.CatalogEmptyHint(context.Background()) - } - if m.session != nil && msg.provider != "" { - gw, _ := m.sessionGatewayModel() - if gw == "" { - gw = msg.provider - } - if strings.TrimSpace(gw) == strings.TrimSpace(msg.provider) { - applyLiveModelMetadata(m.session, gw, m.session.Model()) - } - } - m.invalidateConnStatus() - m.viewDirty = true - if m.configOpen { - if m.configTab == configTabModels { - m = m.focusConfigActiveModelSelection() - } - m.updateViewportContent() - } - return m, nil - - case configApplyCredentialsMsg: - next, cmd := m.handleConfigApplyCredentialsMsg(msg) - if m.configOpen { - next.viewDirty = true - next.updateViewportContent() - } - return next, cmd - - case configGatewayRefreshMsg: - next := m.handleConfigGatewayRefreshMsg(msg) - if m.configOpen { - next.viewDirty = true - next.updateViewportContent() - } - return next, nil - - case configRemoveCredentialMsg: - next, cmd := m.handleConfigRemoveCredentialMsg(msg) - if m.configOpen { - next.viewDirty = true - next.updateViewportContent() - } - return next, cmd - - case loopTickMsg: - if !m.waiting { - result, cmd := m.handleCommand(msg.command) - m.viewDirty = true - m.updateViewportContent() - return result, cmd - } - return m, nil - - case streamChunkMsg: - if m.compacting && !m.manualCompacting { - m.compacting = false - m.brailleSpinner.SetLabel(m.spinnerVerb) - } - m.turnHadAssistantOutput = true - m.partial.WriteString(string(msg)) - m.markPartialDirty() - if m.viewDirty { - m.updateViewportContent() - } - return m, nil - - case thinkingMsg: - m.turnSawThinking = true - return m, nil - - case streamRetryMsg: - m.partial.Reset() - m.messages = stripCurrentTurnThinking(m.messages) - m.turnSawThinking = false - m.turnHadAssistantOutput = false - m.turnHadToolActivity = false - m.messages = append(m.messages, displayMsg{role: "system", content: "↻ " + msg.content}) - m.viewDirty = true - return m, nil - - case toolUseMsg: - m.turnHadToolActivity = true - if m.partial.Len() > 0 { - m.messages = append(m.messages, displayMsg{role: "assistant", content: m.partial.String()}) - m.partial.Reset() - } - m.messages = append(m.messages, displayMsg{role: "tool_use", content: msg.name}) - m.toolStartTime = time.Now() - m.viewDirty = true - return m, nil - - case toolResultMsg: - m.turnHadToolActivity = true - m.messages = append(m.messages, displayMsg{role: "tool_result", content: fmt.Sprintf("[%s] %s", msg.name, msg.content)}) - m.viewDirty = true - return m, nil - - case blastRadiusMsg: - m.messages = append(m.messages, displayMsg{role: "system", content: msg.message}) - m.viewDirty = true - return m, nil - - case selectionResumedMsg: - // Returned from enterSelectionMode. The terminal has been - // restored; just trigger a redraw so the viewport reflects the - // state that was visible before selection. - m.viewDirty = true - m.updateViewportContent() - return m, nil - - case permissionAskMsg: - m.permReq = &msg.req - m.messages = append(m.messages, displayMsg{role: "permission", content: msg.req.Summary}) - m.viewDirty = true - return m, nil - - case askUserMsg: - m.askReq = &msg - m.messages = append(m.messages, displayMsg{role: "question", content: icons.HelpCircle() + " " + msg.question}) - m.viewDirty = true - m.input.Focus() - m.input.SetValue("") - return m, nil - - case usageUpdateMsg: - if msg.usage != nil { - m.turnInputTokens += msg.usage.PromptTokens - m.turnOutputTokens += msg.usage.CompletionTokens - m.invalidateConnStatus() - m.viewDirty = true - } - return m, nil - - case compactTickMsg: - if m.manualCompacting { - if m.brailleSpinner != nil { - m.brailleSpinner.Tick() - } - m.viewDirty = true - m.updateViewportContent() - localCmds := []tea.Cmd{compactTickCmd()} - if !m.input.Focused() { - localCmds = append(localCmds, m.input.Focus()) - } - return m, tea.Batch(localCmds...) - } - return m, nil - - case compactDoneMsg: - return m.finishManualCompact(msg) - - case compactStartMsg: - if !m.manualCompacting { - m.compacting = true - m.brailleSpinner.SetLabel("Compacting context") - m.viewDirty = true - } - return m, nil - - case compactMsg: - m.compacting = false - m.brailleSpinner.SetLabel(m.spinnerVerb) - line := fmt.Sprintf( - "Context compacted (%s): ~%s → ~%s tokens", - msg.strategy, - formatHawkTokenCount(msg.tokensBefore), - formatHawkTokenCount(msg.tokensAfter), - ) - m.messages = append(m.messages, displayMsg{role: "system", content: line}) - m.invalidateConnStatus() - m.viewDirty = true - return m, nil - - case streamDoneMsg: - if m.streamCancelled { - m.streamCancelled = false - m.waiting = false - m.cancel = nil - m.toolStartTime = time.Time{} - m.viewDirty = true - return m, nil - } - if m.compacting { - m.compacting = false - m.brailleSpinner.SetLabel(m.spinnerVerb) - } - m.invalidateConnStatus() - m.flushPartialDirty() - if m.partial.Len() > 0 { - content := sanitizeIdentity(m.partial.String()) - m.messages = append(m.messages, displayMsg{role: "assistant", content: content}) - if m.wal != nil { - _ = m.wal.Append(session.Message{Role: "assistant", Content: content}) - } - // Generate ghost text suggestion from AI response - m.ghostText.Suggest(content) - m.partial.Reset() - } else if m.turnSawThinking && !m.turnHadAssistantOutput && !m.turnHadToolActivity { - // Model sent reasoning tokens but no answer — common with reasoning - // models when the provider drops the post-reasoning content. - m.messages = append(m.messages, displayMsg{ - role: "error", - content: friendlyError(fmt.Errorf("error_only_reasoning: model produced reasoning but no answer")), - }) - } - m.turnSawThinking = false - m.turnHadAssistantOutput = false - m.turnHadToolActivity = false - m.waiting = false - m.cancel = nil - m.toolStartTime = time.Time{} - m.viewDirty = true - m.input.Focus() - m.saveSession() - - // Process queued messages - if len(m.messageQueue) > 0 { - nextMsg := m.messageQueue[0] - m.messageQueue = m.messageQueue[1:] - m.messages = append(m.messages, displayMsg{role: "user", content: nextMsg}) - m.session.AddUser(nextMsg) - m.waiting = true - m.autoScroll = true - m.viewDirty = true - m.spinnerVerb = spinnerVerbs[rand.Intn(len(spinnerVerbs))] - m.brailleSpinner.SetLabel(m.spinnerVerb) - m.turnSawThinking = false - m.turnHadAssistantOutput = false - m.turnHadToolActivity = false - m.turnInputTokens = 0 - m.turnOutputTokens = 0 - m.startedAt = time.Time{} - m.partial.Reset() - m.startStream() - } - - return m, nil - - case streamErrMsg: - m.messages = append(m.messages, displayMsg{role: "error", content: friendlyError(msg.err)}) - m.partial.Reset() - m.waiting = false - m.cancel = nil - m.toolStartTime = time.Time{} - m.viewDirty = true - m.input.Focus() - return m, nil - - case blinkTickMsg: - m.blinkClosed = !m.blinkClosed - m.rebuildWelcomeCache(m.blinkClosed) - m.viewDirty = true - cmds = append(cmds, blinkTickCmd()) - return m, tea.Batch(cmds...) - - case spinnerVerbTickMsg: - cmds = append(cmds, spinnerVerbTickCmd()) - if m.waiting && m.partial.Len() == 0 { - m.spinnerVerb = spinnerVerbs[rand.Intn(len(spinnerVerbs))] - m.brailleSpinner.SetLabel(m.spinnerVerb) - m.viewDirty = true - } - return m, tea.Batch(cmds...) - - case tea.WindowSizeMsg: - m.width = msg.Width - m.height = msg.Height - if !m.onWelcomeGate() { - m.input.SetWidth(msg.Width - 4) - } - m.invalidateInputLayoutCache() - m.rebuildWelcomeCache(false) - m.viewDirty = true - m.refreshInputLayoutIfNeeded() - m = m.withSyncedLayout() - - case spinner.TickMsg: - var cmd tea.Cmd - m.spinner, cmd = m.spinner.Update(msg) - if m.waiting && m.partial.Len() == 0 { - m.brailleSpinner.Tick() - // Lazy-init startedAt here (Update path) so the spinner - // line's elapsed timer has a reference point. Kept out of - // the View path so render stays a pure function. - if m.startedAt.IsZero() { - m.startedAt = time.Now() - } - // Lerp the displayed token counters toward the engine's - // actual numbers — also done here, not in View. - m.displayInTok += (float64(m.tokenInputTarget()) - m.displayInTok) * 0.10 - m.displayOutTok += (float64(m.tokenOutputTarget()) - m.displayOutTok) * 0.10 - m.viewDirty = true - } - cmds = append(cmds, cmd) - - case containerStatusMsg: - m.containerStatus = msg.status - m.containerReady = msg.ready - m.containerErr = msg.err - if msg.sandbox != nil { - m.containerSandbox = msg.sandbox - if m.session != nil { - m.session.SetContainerExecutor(msg.sandbox) - } - } - if msg.ready && m.session != nil { - if m.session.PermSvc().Autonomy() == 0 { - m.session.PermSvc().SetAutonomy(DefaultContainerAutonomy) - } - if m.phase == phaseWelcomeGate { - m.sandboxReadyPending = true - } else { - m.messages = append(m.messages, displayMsg{role: "system", content: formatSandboxReadyAutonomyMessage(m.session.PermSvc().Autonomy())}) - } - m.invalidateConnStatus() - } - if msg.err != nil { - // Fall back to host mode so chat still works (container is optional). - m.containerEnabled = false - m.containerReady = false - if m.session != nil { - m.session.SetContainerRequired(false) - m.session.SetContainerExecutor(nil) - } - m.messages = append(m.messages, displayMsg{ - role: "system", - content: "Container unavailable — running on host. " + msg.err.Error(), - }) - m.input.Focus() - } - m.rebuildWelcomeCache(m.blinkClosed) - m.viewDirty = true - m.updateViewportContent() - } - - if !m.waiting && m.uiFocus == focusPrompt { - // Clear ghost text when user starts typing - if m.ghostText.Active() && m.input.Value() != "" { - m.ghostText.Clear() - } - // Vim mode key interception (operates on full textarea value) - if m.vim != nil && m.vim.IsEnabled() { - if keyMsg, ok := msg.(tea.KeyMsg); ok { - text := m.input.Value() - // textarea doesn't expose cursor column; use text length as approximation - cursor := len(text) - newText, newCursor, consumed := m.vim.HandleKey(keyMsg, text, cursor) - if consumed { - if newText != text { - m.input.SetValue(newText) - } - m.input.SetCursor(newCursor) - } - if consumed && m.vim.Mode == VimNormal { - return m, tea.Batch(cmds...) - } - } - } - if shouldForwardToInput(msg) { - cmds = append(cmds, m.updateInput(msg)) - } - } - if m.uiFocus == focusPrompt && !m.input.Focused() { - cmds = append(cmds, m.input.Focus()) - } - - layoutChanged := m.refreshInputLayoutIfNeeded() - if layoutChanged { - m = m.withSyncedLayout() - } - if m.viewDirty || layoutChanged { - m.updateViewportContent() - } - - return m, tea.Batch(cmds...) -} - // autoIndexCodegraph runs codegraph indexing in the background on startup. // Only indexes if .codegraph/ already exists (user has initialized it before). // Uses Sync for incremental updates (only re-indexes changed files). diff --git a/cmd/chat_tools.go b/cmd/chat_tools.go new file mode 100644 index 00000000..f8710650 --- /dev/null +++ b/cmd/chat_tools.go @@ -0,0 +1,152 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "strings" + + hawkconfig "github.com/GrayCodeAI/hawk/internal/config" + "github.com/GrayCodeAI/hawk/internal/tool" +) + +// This file holds the tool-registry construction used by the chat TUI: +// the essential/optional tool sets and the registry builder that wires in +// MCP servers and CLI tool filters. Split out of chat.go for clarity. + +func essentialTools() []tool.Tool { + // Core tools needed for basic agent operation - always loaded at startup + return []tool.Tool{ + tool.BashTool{}, + tool.FileReadTool{}, + tool.FileWriteTool{}, + tool.FileEditTool{}, + tool.StructuredEditTool{}, + tool.LSTool{}, + tool.GlobTool{}, + tool.GrepTool{}, + tool.WebFetchTool{}, + tool.WebSearchTool{}, + tool.ToolSearchTool{}, + tool.SkillTool{}, + tool.AgentTool{}, + tool.AskUserQuestionTool{}, + tool.TodoWriteTool{}, + tool.TaskOutputTool{}, + tool.TaskStopTool{}, + tool.LSPTool{}, + tool.MultiEditTool{}, + } +} + +func optionalTools() []tool.Tool { + // Specialized tools that can be lazy-loaded on demand + return []tool.Tool{ + tool.EnterPlanModeTool{}, + tool.ExitPlanModeTool{}, + tool.NotebookEditTool{}, + tool.EnterWorktreeTool{}, + tool.ExitWorktreeTool{}, + tool.ListMcpResourcesTool{}, + tool.ReadMcpResourceTool{}, + tool.ConfigTool{}, + tool.BriefTool{}, + tool.TaskCreateTool{}, + tool.TaskGetTool{}, + tool.TaskListTool{}, + tool.TaskUpdateTool{}, + tool.SleepTool{}, + tool.CronCreateTool{}, + tool.CronDeleteTool{}, + tool.CronListTool{}, + tool.VerifyPlanExecutionTool{}, + tool.WorkflowTool{}, + tool.McpAuthTool{}, + tool.DiagnosticsTool{}, + tool.CodeSearchTool{}, + tool.CoreMemoryAppendTool{}, + tool.CoreMemoryReplaceTool{}, + tool.CoreMemoryRethinkTool{}, + tool.DownloadTool{}, + tool.AgenticFetchTool{}, + tool.ImpactTool{}, + tool.GitHistoryTool{}, + tool.CodeGraphTool{}, + tool.NilAwayTool{}, + tool.ReviveTool{}, + tool.MCPLanguageServerTool{}, + tool.SQLTool{}, + } +} + +func defaultRegistry(settings hawkconfig.Settings) (*tool.Registry, error) { + // Load essential tools first for fast startup + tools := essentialTools() + if tool.IsPowerShellAvailable() { + tools = append(tools, tool.PowerShellTool{}) + } + // Detect project-level MCP servers (supply chain attack vector). + // Project .hawk/settings.json can be committed to a repo and define + // arbitrary commands that execute on clone. Gate behind --allow-project-mcp. + projectMCPServers := hawkconfig.ProjectMCPServers() + projectMCPNames := make(map[string]bool, len(projectMCPServers)) + for _, cfg := range projectMCPServers { + if cfg.Name != "" { + projectMCPNames[cfg.Name] = true + } + } + for _, cfg := range settings.MCPServers { + if cfg.Name == "" || cfg.Command == "" { + continue + } + if projectMCPNames[cfg.Name] && !allowProjectMCP { + fmt.Fprintf(os.Stderr, "hawk: skipping project-level MCP server %q (defined in .hawk/settings.json); use --allow-project-mcp to enable\n", cfg.Name) + continue + } + mcpTools, err := tool.LoadMCPTools(context.Background(), cfg.Name, cfg.Command, cfg.Args...) + if err != nil { + continue + } + tools = append(tools, mcpTools...) + } + // Load MCP server tools + for _, cmd := range mcpServers { + parts := strings.Fields(cmd) + if len(parts) == 0 { + continue + } + name := parts[0] + mcpTools, err := tool.LoadMCPTools(context.Background(), name, parts[0], parts[1:]...) + if err != nil { + // MCP server failed to connect — skip silently, will show in /doctor + continue + } + tools = append(tools, mcpTools...) + } + + filtered, err := filterAvailableTools( + tools, + toolsFlagSet, + parseToolListFromCLI(toolsFlag), + parseToolListFromCLI(disallowedToolsFlag), + ) + if err != nil { + return nil, err + } + registry := tool.NewRegistry(filtered...) + + // Lazy-load optional tools in background + go func() { + for _, t := range optionalTools() { + _ = registry.Register(t) + } + }() + + return registry, nil +} + +func allTools() []tool.Tool { + t := essentialTools() + t = append(t, optionalTools()...) + return t +} diff --git a/cmd/chat_update.go b/cmd/chat_update.go new file mode 100644 index 00000000..c993ddab --- /dev/null +++ b/cmd/chat_update.go @@ -0,0 +1,804 @@ +package cmd + +import ( + "context" + "fmt" + "math/rand" + "strings" + "time" + + "github.com/charmbracelet/bubbles/spinner" + tea "github.com/charmbracelet/bubbletea" + + hawkconfig "github.com/GrayCodeAI/hawk/internal/config" + "github.com/GrayCodeAI/hawk/internal/session" + "github.com/GrayCodeAI/hawk/internal/ui/icons" +) + +// This file holds the Bubble Tea event loop for the chat TUI: the central +// Update message switch and the prompt arrow-key handler. Split out of +// chat.go so the model construction/lifecycle and the event loop live in +// separate, focused files. + +// applyPromptArrowKey handles Up/Down in the prompt: slash menu navigation or input history. +// Returns true when the key was consumed so callers skip textarea/updateInput handling. +func (m *chatModel) applyPromptArrowKey(msg tea.KeyMsg) bool { + if m.uiFocus != focusPrompt || m.configOpen { + return false + } + switch msg.Type { + case tea.KeyUp, tea.KeyDown: + default: + return false + } + sugs := m.slashSuggestionsFor(m.input.Value()) + if len(sugs) > 0 { + switch msg.Type { + case tea.KeyUp: + if m.slashSel <= 0 { + m.slashSel = len(sugs) - 1 + } else { + m.slashSel-- + } + case tea.KeyDown: + m.slashSel = (m.slashSel + 1) % len(sugs) + } + return true + } + switch msg.Type { + case tea.KeyUp: + if len(m.history) > 0 { + if m.historyIdx == len(m.history) { + m.historyDraft = m.input.Value() + } + if m.historyIdx > 0 { + m.historyIdx-- + m.input.SetValue(m.history[m.historyIdx]) + m.input.CursorEnd() + } + } + return true + case tea.KeyDown: + if m.historyIdx < len(m.history)-1 { + m.historyIdx++ + m.input.SetValue(m.history[m.historyIdx]) + m.input.CursorEnd() + } else if m.historyIdx == len(m.history)-1 { + m.historyIdx = len(m.history) + m.input.SetValue(m.historyDraft) + m.input.CursorEnd() + } + return true + } + return false +} + +func (m chatModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd + + switch msg := msg.(type) { + case tea.MouseMsg: + if m.mouseEnabled() { + if tea.MouseEvent(msg).IsWheel() { + m.trackMousePosition(msg) + cmds = append(cmds, m.applyMouseScroll(msg)) + m.sanitizeInputIfNeeded() + m = m.syncViewportMouseWheel().withSyncedLayout() + if m.syncInputLayout() { + m.updateViewportContent() + } + if focus := m.ensurePromptInputFocus(); focus != nil { + cmds = append(cmds, focus) + } + } else { + // Motion events (?1003): track pointer only — avoid layout/sanitize/focus per move. + m.trackMousePosition(msg) + } + } + return m, tea.Batch(cmds...) + + case autoOpenConfigMsg: + if !m.openConfigOnStart || m.configOpen { + return m, nil + } + m.openConfigOnStart = false + return m.openConfigPanel() + case tea.KeyMsg: + // Ctrl+\ enters native terminal selection mode. Available in every UI + // state (welcome gate, permissions, prompt, scrollback) so users always + // have a way to copy text out of the chat — the alt-screen + + // mouse-tracking combination otherwise breaks native text selection. + if msg.Type == tea.KeyCtrlBackslash { + return m, enterSelectionMode(m.ref, m.copyableTranscript(), m.mouseEnabled()) + } + if isCopyToClipboardKey(msg) { + return m.handleCopyShortcut() + } + if isMouseSequenceLeak(msg) { + if handled, cmd := m.tryScrollFromMouseLeak(msg); handled { + m.sanitizeInputIfNeeded() + if focus := m.ensurePromptInputFocus(); focus != nil { + return m, tea.Batch(cmd, focus) + } + return m, cmd + } + m.sanitizeInputIfNeeded() + if focus := m.ensurePromptInputFocus(); focus != nil { + return m, focus + } + return m, nil + } + if next, cmd, handled := m.handleWelcomeGateKey(msg); handled { + return next, cmd + } + + // Command palette (Ctrl+K) — intercept all input when open + if m.commandPalette != nil && m.commandPalette.IsOpen() { + action, handled := m.commandPalette.Update(msg) + if handled { + if action != "" { + // Execute the selected command + m.commandPalette.Close() + result, _ := m.handleCommand(action) + if cm, ok := result.(chatModel); ok { + m = cm + } + m.viewDirty = true + m.updateViewportContent() + } + return m, nil + } + } + + if m.manualCompacting { + if isCompactCancelKey(msg) { + return m.cancelManualCompact("Compaction cancelled.") + } + if msg.Type == tea.KeyEnter { + return m, nil + } + // Allow typing in the input while compaction runs (Esc cancels). + } + + if m.inScrollbackFocus() { + switch msg.Type { + case tea.KeyTab: + return m.cycleUIFocus() + case tea.KeyEsc: + m.uiFocus = focusPrompt + m.viewDirty = true + return m, m.input.Focus() + } + if scrolled, cmd := m.applyViewportScroll(msg); scrolled { + return m, cmd + } + if m.routeKeyToViewport(msg) { + var cmd tea.Cmd + m.viewport, cmd = m.viewport.Update(msg) + if m.viewport.AtBottom() { + m.autoScroll = true + } else { + m.autoScroll = false + } + return m, cmd + } + return m, nil + } + + if scrolled, cmd := m.applyViewportScroll(msg); scrolled { + return m, cmd + } + + // Permission prompt active — handle y/n + if m.permReq != nil { + switch msg.String() { + case "y", "Y": + m.permReq.Response <- true + m.messages = append(m.messages, displayMsg{role: "system", content: icons.CheckBold() + " Allowed"}) + m.permReq = nil + case "n", "N": + m.permReq.Response <- false + m.messages = append(m.messages, displayMsg{role: "system", content: icons.CloseThick() + " Denied"}) + m.permReq = nil + case "a", "A": + m.permReq.Response <- true + m.session.Perm.Memory.AlwaysAllow(m.permReq.ToolName) + m.messages = append(m.messages, displayMsg{role: "system", content: icons.CheckBold() + " Always allowed: " + m.permReq.ToolName}) + m.permReq = nil + } + m.viewDirty = true + m.updateViewportContent() + return m, nil + } + // AskUser prompt active — Enter submits answer + if m.askReq != nil { + if msg.Type == tea.KeyEnter { + answer := strings.TrimSpace(m.input.Value()) + m.input.Reset() + m.messages = append(m.messages, displayMsg{role: "user", content: answer}) + m.askReq.response <- answer + m.askReq = nil + m.viewDirty = true + m.updateViewportContent() + return m, nil + } + return m, m.updateInput(msg) + } + if m.waiting { + if msg.Type == tea.KeyCtrlC { + // First Ctrl+C cancels stream, second quits + if m.cancel != nil { + m.cancel() + m.cancel = nil + m.streamCancelled = true + m.messages = append(m.messages, displayMsg{role: "system", content: icons.Stop() + " Cancelled."}) + if m.partial.Len() > 0 { + m.messages = append(m.messages, displayMsg{role: "assistant", content: m.partial.String()}) + m.partial.Reset() + } + m.waiting = false + m.input.Focus() + m.viewDirty = true + m.updateViewportContent() + return m, nil + } + m.saveSession() + if m.watcherStop != nil { + m.watcherStop() + } + m.quitting = true + return m, tea.Quit + } + if msg.Type == tea.KeyEsc { + if m.cancel != nil { + m.cancel() + m.cancel = nil + m.streamCancelled = true + m.messages = append(m.messages, displayMsg{role: "system", content: icons.Stop() + " Cancelled."}) + if m.partial.Len() > 0 { + m.messages = append(m.messages, displayMsg{role: "assistant", content: m.partial.String()}) + m.partial.Reset() + } + m.waiting = false + m.input.Focus() + } + m.viewDirty = true + m.updateViewportContent() + return m, nil + } + // Queue message on Enter while agent is working + if msg.Type == tea.KeyEnter { + text := strings.TrimSpace(m.input.Value()) + if text != "" { + m.messageQueue = append(m.messageQueue, text) + m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("%s Queued: %s", icons.Mail(), text)}) + m.input.Reset() + m.viewDirty = true + m.updateViewportContent() + } + return m, nil + } + if m.applyPromptArrowKey(msg) { + return m, nil + } + return m, m.updateInput(msg) + } + if m.configOpen { + switch msg.Type { + case tea.KeyCtrlC: + if time.Since(m.lastCtrlC) < 1*time.Second { + m.saveSession() + if m.watcherStop != nil { + m.watcherStop() + } + m.quitting = true + return m, tea.Quit + } + m.lastCtrlC = time.Now() + m.messages = append(m.messages, displayMsg{role: "system", content: quitAgainMsg}) + m.viewDirty = true + m.updateViewportContent() + return m, nil + default: + next, cmd := m.handleConfigKey(msg) + next.viewDirty = true + next.updateViewportContent() + return next, cmd + } + } + switch msg.Type { + case tea.KeyCtrlA: + // Toggle the Agent Status HUD overlay. + m.hudOpen = !m.hudOpen + if m.hudOpen { + m.hudData = m.collectHUDData() + } + m.viewDirty = true + m.updateViewportContent() + return m, nil + case tea.KeyCtrlK: + // Open command palette + if m.commandPalette == nil { + m.commandPalette = NewCommandPalette(m.width) + } + m.commandPalette.Open() + m.viewDirty = true + m.updateViewportContent() + return m, nil + case tea.KeyCtrlN: + models := configModelChoices(m.configModelOptions, false) + if len(models) > 1 { + current := m.session.Model() + idx := 0 + for i, md := range models { + if md == current { + idx = (i + 1) % len(models) + break + } + } + m.session.SetModel(models[idx]) + m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("Model → %s", models[idx])}) + } + m.viewDirty = true + m.updateViewportContent() + return m, nil + case tea.KeyCtrlL: + if m.containerEnabled && !m.containerReady { + m.messages = append(m.messages, displayMsg{role: "system", content: "Waiting for sandbox — tiers unlock when container is ready."}) + m.viewDirty = true + m.updateViewportContent() + return m, nil + } + next := nextAutonomyTier(m.session.PermSvc().Autonomy()) + if m.session.PermSvc().Autonomy() == 0 || autonomyTierIndex(m.session.PermSvc().Autonomy()) < 0 { + next = DefaultContainerAutonomy + } + m.session.PermSvc().SetAutonomy(next) + m.invalidateConnStatus() + m.messages = append(m.messages, displayMsg{role: "system", content: formatAutonomyTierMessage(next)}) + m.viewDirty = true + m.updateViewportContent() + return m, nil + case tea.KeyCtrlC: + if time.Since(m.lastCtrlC) < 1*time.Second { + m.saveSession() + if m.watcherStop != nil { + m.watcherStop() + } + m.quitting = true + return m, tea.Quit + } + m.lastCtrlC = time.Now() + m.messages = append(m.messages, displayMsg{role: "system", content: quitAgainMsg}) + m.viewDirty = true + m.updateViewportContent() + return m, nil + case tea.KeyTab: + // Accept ghost text suggestion if active and input is empty + if m.ghostText.Active() && strings.TrimSpace(m.input.Value()) == "" { + accepted := m.ghostText.Accept() + m.input.SetValue(accepted) + m.input.CursorEnd() + return m, nil + } + sugs := m.slashSuggestionsFor(m.input.Value()) + if len(sugs) > 0 { + if m.slashSel < 0 || m.slashSel >= len(sugs) { + m.slashSel = 0 + } + m.input.SetValue(applySlashSuggestion(sugs[m.slashSel])) + m.input.CursorEnd() + return m, nil + } + return m.cycleUIFocus() + case tea.KeyUp, tea.KeyDown: + if m.applyPromptArrowKey(msg) { + return m, nil + } + case tea.KeyEsc: + if len(m.slashSuggestionsFor(m.input.Value())) > 0 { + m.slashSel = 0 + return m, nil + } + case tea.KeyEnter: + return m.submitUserMessage() + } + + case modelsFetchedMsg: + m.configSaving = false + if msg.err != nil { + if m.configOpen { + m.configNotice = sanitizeConfigNotice(hawkconfig.FormatConfigProviderError(msg.provider, msg.err)) + m.viewDirty = true + m.updateViewportContent() + } + return m, nil + } + if len(msg.options) > 0 { + m.configModelOptions = msg.options + if msg.provider != "" { + modelCacheMu.Lock() + modelCache[msg.provider] = msg.options + modelCacheMu.Unlock() + } + if m.configOpen && strings.Contains(m.configNotice, "Loading") { + m.configNotice = "" + } + } else if m.configOpen { + m.configNotice = hawkconfig.CatalogEmptyHint(context.Background()) + } + if m.session != nil && msg.provider != "" { + gw, _ := m.sessionGatewayModel() + if gw == "" { + gw = msg.provider + } + if strings.TrimSpace(gw) == strings.TrimSpace(msg.provider) { + applyLiveModelMetadata(m.session, gw, m.session.Model()) + } + } + m.invalidateConnStatus() + m.viewDirty = true + if m.configOpen { + if m.configTab == configTabModels { + m = m.focusConfigActiveModelSelection() + } + m.updateViewportContent() + } + return m, nil + + case configApplyCredentialsMsg: + next, cmd := m.handleConfigApplyCredentialsMsg(msg) + if m.configOpen { + next.viewDirty = true + next.updateViewportContent() + } + return next, cmd + + case configGatewayRefreshMsg: + next := m.handleConfigGatewayRefreshMsg(msg) + if m.configOpen { + next.viewDirty = true + next.updateViewportContent() + } + return next, nil + + case configRemoveCredentialMsg: + next, cmd := m.handleConfigRemoveCredentialMsg(msg) + if m.configOpen { + next.viewDirty = true + next.updateViewportContent() + } + return next, cmd + + case loopTickMsg: + if !m.waiting { + result, cmd := m.handleCommand(msg.command) + m.viewDirty = true + m.updateViewportContent() + return result, cmd + } + return m, nil + + case streamChunkMsg: + if m.compacting && !m.manualCompacting { + m.compacting = false + m.brailleSpinner.SetLabel(m.spinnerVerb) + } + m.turnHadAssistantOutput = true + m.partial.WriteString(string(msg)) + m.markPartialDirty() + if m.viewDirty { + m.updateViewportContent() + } + return m, nil + + case thinkingMsg: + m.turnSawThinking = true + return m, nil + + case streamRetryMsg: + m.partial.Reset() + m.messages = stripCurrentTurnThinking(m.messages) + m.turnSawThinking = false + m.turnHadAssistantOutput = false + m.turnHadToolActivity = false + m.messages = append(m.messages, displayMsg{role: "system", content: "↻ " + msg.content}) + m.viewDirty = true + return m, nil + + case toolUseMsg: + m.turnHadToolActivity = true + if m.partial.Len() > 0 { + m.messages = append(m.messages, displayMsg{role: "assistant", content: m.partial.String()}) + m.partial.Reset() + } + m.messages = append(m.messages, displayMsg{role: "tool_use", content: msg.name}) + m.toolStartTime = time.Now() + m.viewDirty = true + return m, nil + + case toolResultMsg: + m.turnHadToolActivity = true + m.messages = append(m.messages, displayMsg{role: "tool_result", content: fmt.Sprintf("[%s] %s", msg.name, msg.content)}) + m.viewDirty = true + return m, nil + + case blastRadiusMsg: + m.messages = append(m.messages, displayMsg{role: "system", content: msg.message}) + m.viewDirty = true + return m, nil + + case selectionResumedMsg: + // Returned from enterSelectionMode. The terminal has been + // restored; just trigger a redraw so the viewport reflects the + // state that was visible before selection. + m.viewDirty = true + m.updateViewportContent() + return m, nil + + case permissionAskMsg: + m.permReq = &msg.req + m.messages = append(m.messages, displayMsg{role: "permission", content: msg.req.Summary}) + m.viewDirty = true + return m, nil + + case askUserMsg: + m.askReq = &msg + m.messages = append(m.messages, displayMsg{role: "question", content: icons.HelpCircle() + " " + msg.question}) + m.viewDirty = true + m.input.Focus() + m.input.SetValue("") + return m, nil + + case usageUpdateMsg: + if msg.usage != nil { + m.turnInputTokens += msg.usage.PromptTokens + m.turnOutputTokens += msg.usage.CompletionTokens + m.invalidateConnStatus() + m.viewDirty = true + } + return m, nil + + case compactTickMsg: + if m.manualCompacting { + if m.brailleSpinner != nil { + m.brailleSpinner.Tick() + } + m.viewDirty = true + m.updateViewportContent() + localCmds := []tea.Cmd{compactTickCmd()} + if !m.input.Focused() { + localCmds = append(localCmds, m.input.Focus()) + } + return m, tea.Batch(localCmds...) + } + return m, nil + + case compactDoneMsg: + return m.finishManualCompact(msg) + + case compactStartMsg: + if !m.manualCompacting { + m.compacting = true + m.brailleSpinner.SetLabel("Compacting context") + m.viewDirty = true + } + return m, nil + + case compactMsg: + m.compacting = false + m.brailleSpinner.SetLabel(m.spinnerVerb) + line := fmt.Sprintf( + "Context compacted (%s): ~%s → ~%s tokens", + msg.strategy, + formatHawkTokenCount(msg.tokensBefore), + formatHawkTokenCount(msg.tokensAfter), + ) + m.messages = append(m.messages, displayMsg{role: "system", content: line}) + m.invalidateConnStatus() + m.viewDirty = true + return m, nil + + case streamDoneMsg: + if m.streamCancelled { + m.streamCancelled = false + m.waiting = false + m.cancel = nil + m.toolStartTime = time.Time{} + m.viewDirty = true + return m, nil + } + if m.compacting { + m.compacting = false + m.brailleSpinner.SetLabel(m.spinnerVerb) + } + m.invalidateConnStatus() + m.flushPartialDirty() + if m.partial.Len() > 0 { + content := sanitizeIdentity(m.partial.String()) + m.messages = append(m.messages, displayMsg{role: "assistant", content: content}) + if m.wal != nil { + _ = m.wal.Append(session.Message{Role: "assistant", Content: content}) + } + // Generate ghost text suggestion from AI response + m.ghostText.Suggest(content) + m.partial.Reset() + } else if m.turnSawThinking && !m.turnHadAssistantOutput && !m.turnHadToolActivity { + // Model sent reasoning tokens but no answer — common with reasoning + // models when the provider drops the post-reasoning content. + m.messages = append(m.messages, displayMsg{ + role: "error", + content: friendlyError(fmt.Errorf("error_only_reasoning: model produced reasoning but no answer")), + }) + } + m.turnSawThinking = false + m.turnHadAssistantOutput = false + m.turnHadToolActivity = false + m.waiting = false + m.cancel = nil + m.toolStartTime = time.Time{} + m.viewDirty = true + m.input.Focus() + m.saveSession() + + // Process queued messages + if len(m.messageQueue) > 0 { + nextMsg := m.messageQueue[0] + m.messageQueue = m.messageQueue[1:] + m.messages = append(m.messages, displayMsg{role: "user", content: nextMsg}) + m.session.AddUser(nextMsg) + m.waiting = true + m.autoScroll = true + m.viewDirty = true + m.spinnerVerb = spinnerVerbs[rand.Intn(len(spinnerVerbs))] + m.brailleSpinner.SetLabel(m.spinnerVerb) + m.turnSawThinking = false + m.turnHadAssistantOutput = false + m.turnHadToolActivity = false + m.turnInputTokens = 0 + m.turnOutputTokens = 0 + m.startedAt = time.Time{} + m.partial.Reset() + m.startStream() + } + + return m, nil + + case streamErrMsg: + m.messages = append(m.messages, displayMsg{role: "error", content: friendlyError(msg.err)}) + m.partial.Reset() + m.waiting = false + m.cancel = nil + m.toolStartTime = time.Time{} + m.viewDirty = true + m.input.Focus() + return m, nil + + case blinkTickMsg: + m.blinkClosed = !m.blinkClosed + m.rebuildWelcomeCache(m.blinkClosed) + m.viewDirty = true + cmds = append(cmds, blinkTickCmd()) + return m, tea.Batch(cmds...) + + case spinnerVerbTickMsg: + cmds = append(cmds, spinnerVerbTickCmd()) + if m.waiting && m.partial.Len() == 0 { + m.spinnerVerb = spinnerVerbs[rand.Intn(len(spinnerVerbs))] + m.brailleSpinner.SetLabel(m.spinnerVerb) + m.viewDirty = true + } + return m, tea.Batch(cmds...) + + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + if !m.onWelcomeGate() { + m.input.SetWidth(msg.Width - 4) + } + m.invalidateInputLayoutCache() + m.rebuildWelcomeCache(false) + m.viewDirty = true + m.refreshInputLayoutIfNeeded() + m = m.withSyncedLayout() + + case spinner.TickMsg: + var cmd tea.Cmd + m.spinner, cmd = m.spinner.Update(msg) + if m.waiting && m.partial.Len() == 0 { + m.brailleSpinner.Tick() + // Lazy-init startedAt here (Update path) so the spinner + // line's elapsed timer has a reference point. Kept out of + // the View path so render stays a pure function. + if m.startedAt.IsZero() { + m.startedAt = time.Now() + } + // Lerp the displayed token counters toward the engine's + // actual numbers — also done here, not in View. + m.displayInTok += (float64(m.tokenInputTarget()) - m.displayInTok) * 0.10 + m.displayOutTok += (float64(m.tokenOutputTarget()) - m.displayOutTok) * 0.10 + m.viewDirty = true + } + cmds = append(cmds, cmd) + + case containerStatusMsg: + m.containerStatus = msg.status + m.containerReady = msg.ready + m.containerErr = msg.err + if msg.sandbox != nil { + m.containerSandbox = msg.sandbox + if m.session != nil { + m.session.SetContainerExecutor(msg.sandbox) + } + } + if msg.ready && m.session != nil { + if m.session.PermSvc().Autonomy() == 0 { + m.session.PermSvc().SetAutonomy(DefaultContainerAutonomy) + } + if m.phase == phaseWelcomeGate { + m.sandboxReadyPending = true + } else { + m.messages = append(m.messages, displayMsg{role: "system", content: formatSandboxReadyAutonomyMessage(m.session.PermSvc().Autonomy())}) + } + m.invalidateConnStatus() + } + if msg.err != nil { + // Fall back to host mode so chat still works (container is optional). + m.containerEnabled = false + m.containerReady = false + if m.session != nil { + m.session.SetContainerRequired(false) + m.session.SetContainerExecutor(nil) + } + m.messages = append(m.messages, displayMsg{ + role: "system", + content: "Container unavailable — running on host. " + msg.err.Error(), + }) + m.input.Focus() + } + m.rebuildWelcomeCache(m.blinkClosed) + m.viewDirty = true + m.updateViewportContent() + } + + if !m.waiting && m.uiFocus == focusPrompt { + // Clear ghost text when user starts typing + if m.ghostText.Active() && m.input.Value() != "" { + m.ghostText.Clear() + } + // Vim mode key interception (operates on full textarea value) + if m.vim != nil && m.vim.IsEnabled() { + if keyMsg, ok := msg.(tea.KeyMsg); ok { + text := m.input.Value() + // textarea doesn't expose cursor column; use text length as approximation + cursor := len(text) + newText, newCursor, consumed := m.vim.HandleKey(keyMsg, text, cursor) + if consumed { + if newText != text { + m.input.SetValue(newText) + } + m.input.SetCursor(newCursor) + } + if consumed && m.vim.Mode == VimNormal { + return m, tea.Batch(cmds...) + } + } + } + if shouldForwardToInput(msg) { + cmds = append(cmds, m.updateInput(msg)) + } + } + if m.uiFocus == focusPrompt && !m.input.Focused() { + cmds = append(cmds, m.input.Focus()) + } + + layoutChanged := m.refreshInputLayoutIfNeeded() + if layoutChanged { + m = m.withSyncedLayout() + } + if m.viewDirty || layoutChanged { + m.updateViewportContent() + } + + return m, tea.Batch(cmds...) +} From d23f4e1cc7a5adbe4ee2b0a3f9a5d942f69cc9f1 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 18:43:26 +0530 Subject: [PATCH 02/20] refactor(review): extract built-in rules from review_bot.go into review_bot_rules.go --- internal/engine/review/review_bot.go | 920 +------------------- internal/engine/review/review_bot_rules.go | 928 +++++++++++++++++++++ 2 files changed, 932 insertions(+), 916 deletions(-) create mode 100644 internal/engine/review/review_bot_rules.go diff --git a/internal/engine/review/review_bot.go b/internal/engine/review/review_bot.go index 68777c1f..67e3097c 100644 --- a/internal/engine/review/review_bot.go +++ b/internal/engine/review/review_bot.go @@ -12,6 +12,10 @@ import ( // ReviewBot is a rule-based code review engine that produces structured // feedback without requiring an LLM call. +// +// The built-in rule set (builtinReviewRules and the rule* constructors) lives +// in review_bot_rules.go; this file holds the engine, formatting, diff parsing, +// and shared helpers. type ReviewBot struct { Rules []ReviewRule Severity string // minimum severity to report: "error", "warning", "info" @@ -324,922 +328,6 @@ func isChangedLine(lineNo int, diff []diff.DiffLine) bool { return false } -// ---------- built-in rules ---------- - -func builtinReviewRules() []ReviewRule { - return []ReviewRule{ - ruleHardcodedSecrets(), - ruleSQLInjection(), - ruleCommandInjection(), - ruleXSS(), - ruleNPlusOneQuery(), - ruleUnboundedAllocation(), - ruleStringConcatInLoop(), - ruleErrorIgnored(), - ruleNilDereferenceRisk(), - ruleUnclosedResources(), - ruleRaceCondition(), - ruleExportedWithoutDocs(), - ruleInconsistentNaming(), - ruleMagicNumbers(), - ruleTestWithoutAssertion(), - ruleSkippedTests(), - ruleTestFileWithoutTests(), - ruleHardcodedIP(), - ruleTODOsInCode(), - ruleEmptyErrorHandler(), - ruleDeferInLoop(), - ruleUnusedParameter(), - } -} - -// --- Security rules --- - -func ruleHardcodedSecrets() ReviewRule { - patterns := []*regexp.Regexp{ - regexp.MustCompile(`(?i)(password|secret|api_key|apikey|token|private_key)\s*[:=]\s*["']` + "[^\"']{8,}" + `["']`), - regexp.MustCompile(`(?i)(aws_access_key_id|aws_secret_access_key)\s*[:=]\s*["']`), - regexp.MustCompile(`(?i)-----BEGIN (RSA |EC )?PRIVATE KEY-----`), - regexp.MustCompile(`ghp_[A-Za-z0-9]{36}`), - regexp.MustCompile(`sk-[A-Za-z0-9]{20,}`), - } - return ReviewRule{ - ID: "SEC001", - Name: "Hardcoded secret detected", - Category: "security", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - for _, pat := range patterns { - if pat.MatchString(line) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "error", - Category: "security", - Message: "Hardcoded secret detected", - Suggestion: "Use environment variable instead", - RuleID: "SEC001", - }) - break - } - } - } - return comments - }, - } -} - -func ruleSQLInjection() ReviewRule { - patterns := []*regexp.Regexp{ - regexp.MustCompile(`(?i)(fmt\.Sprintf|".*\+.*")\s*.*?(SELECT|INSERT|UPDATE|DELETE|DROP)\s`), - regexp.MustCompile(`(?i)query\s*[:=].*\+\s*\w`), - regexp.MustCompile(`(?i)(Exec|Query|QueryRow)\s*\(\s*fmt\.Sprintf`), - regexp.MustCompile(`(?i)(Exec|Query|QueryRow)\s*\(\s*"[^"]*"\s*\+`), - } - return ReviewRule{ - ID: "SEC002", - Name: "Potential SQL injection", - Category: "security", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - for _, pat := range patterns { - if pat.MatchString(line) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "error", - Category: "security", - Message: "Potential SQL injection via string concatenation", - Suggestion: "Use parameterized queries with placeholder arguments", - RuleID: "SEC002", - }) - break - } - } - } - return comments - }, - } -} - -func ruleCommandInjection() ReviewRule { - pattern := regexp.MustCompile(`(?i)(exec\.Command|os\.system|subprocess\.(call|run|Popen))\s*\(.*\+`) - patternFmt := regexp.MustCompile(`(?i)(exec\.Command|os\.system)\s*\(\s*fmt\.Sprintf`) - return ReviewRule{ - ID: "SEC003", - Name: "Potential command injection", - Category: "security", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if pattern.MatchString(line) || patternFmt.MatchString(line) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "error", - Category: "security", - Message: "Potential command injection via string concatenation", - Suggestion: "Sanitize inputs or use argument lists instead of shell strings", - RuleID: "SEC003", - }) - } - } - return comments - }, - } -} - -func ruleXSS() ReviewRule { - patterns := []*regexp.Regexp{ - regexp.MustCompile(`(?i)innerHTML\s*=`), - regexp.MustCompile(`(?i)document\.write\s*\(`), - regexp.MustCompile(`(?i)fmt\.Fprintf\s*\(\s*w\s*,.*\+`), - regexp.MustCompile(`(?i)template\.HTML\(`), - } - return ReviewRule{ - ID: "SEC004", - Name: "Potential XSS vulnerability", - Category: "security", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - for _, pat := range patterns { - if pat.MatchString(line) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "error", - Category: "security", - Message: "Potential cross-site scripting (XSS) vulnerability", - Suggestion: "Sanitize or escape user input before rendering", - RuleID: "SEC004", - }) - break - } - } - } - return comments - }, - } -} - -// --- Performance rules --- - -func ruleNPlusOneQuery() ReviewRule { - queryPat := regexp.MustCompile(`(?i)(\.Query|\.QueryRow|\.Exec)\s*\(`) - loopPat := regexp.MustCompile(`^\s*(for|range)\s`) - return ReviewRule{ - ID: "PERF001", - Name: "Potential N+1 query", - Category: "performance", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - inLoop := false - loopStart := 0 - braceDepth := 0 - for i, line := range lines { - if loopPat.MatchString(line) { - inLoop = true - loopStart = i + 1 - braceDepth = 0 - } - if inLoop { - braceDepth += strings.Count(line, "{") - strings.Count(line, "}") - if braceDepth <= 0 && i > loopStart { - inLoop = false - } - } - if inLoop && queryPat.MatchString(line) && isChangedLine(i+1, diff) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "warning", - Category: "performance", - Message: "Database query inside loop (potential N+1 query)", - Suggestion: "Batch queries or use a JOIN to fetch all data at once", - RuleID: "PERF001", - }) - } - } - return comments - }, - } -} - -func ruleUnboundedAllocation() ReviewRule { - pattern := regexp.MustCompile(`make\s*\(\s*\[\]\w+\s*,\s*\w+`) - return ReviewRule{ - ID: "PERF002", - Name: "Potentially unbounded allocation", - Category: "performance", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if pattern.MatchString(line) && !strings.Contains(line, "cap") { - // Check if the size variable might be user-controlled. - if strings.Contains(line, "len(") || strings.Contains(line, "req.") || strings.Contains(line, "input") { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "warning", - Category: "performance", - Message: "Potentially unbounded allocation based on external input", - Suggestion: "Add a maximum cap to prevent OOM from large inputs", - RuleID: "PERF002", - }) - } - } - } - return comments - }, - } -} - -func ruleStringConcatInLoop() ReviewRule { - concatPat := regexp.MustCompile(`\w+\s*\+=\s*("|\w)`) - loopPat := regexp.MustCompile(`^\s*(for|range)\s`) - return ReviewRule{ - ID: "PERF003", - Name: "String concatenation in loop", - Category: "performance", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - inLoop := false - loopStart := 0 - braceDepth := 0 - for i, line := range lines { - if loopPat.MatchString(line) { - inLoop = true - loopStart = i + 1 - braceDepth = 0 - } - if inLoop { - braceDepth += strings.Count(line, "{") - strings.Count(line, "}") - if braceDepth <= 0 && i > loopStart { - inLoop = false - } - } - if inLoop && concatPat.MatchString(line) && isChangedLine(i+1, diff) { - if strings.Contains(line, "string") || strings.Contains(line, `"`) || strings.Contains(line, "str") || strings.Contains(line, "result") || strings.Contains(line, "output") { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "warning", - Category: "performance", - Message: "String concatenation inside loop", - Suggestion: "Use strings.Builder for better performance", - RuleID: "PERF003", - }) - } - } - } - return comments - }, - } -} - -// --- Correctness rules --- - -func ruleErrorIgnored() ReviewRule { - patterns := []*regexp.Regexp{ - regexp.MustCompile(`\w+,\s*_\s*:?=\s*\w+.*\(`), - regexp.MustCompile(`^\s*\w+\.\w+\(.*\)\s*$`), - } - errFuncPat := regexp.MustCompile(`(?i)(write|close|flush|send|remove|delete|create)`) - return ReviewRule{ - ID: "CORR001", - Name: "Error return value ignored", - Category: "correctness", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - // Pattern: val, _ := someFunc() - if patterns[0].MatchString(line) && errFuncPat.MatchString(line) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "warning", - Category: "correctness", - Message: "Error return value ignored", - Suggestion: "if err != nil { return err }", - RuleID: "CORR001", - }) - } - } - return comments - }, - } -} - -func ruleNilDereferenceRisk() ReviewRule { - returnPat := regexp.MustCompile(`return\s+nil\s*,`) - usePat := regexp.MustCompile(`(\w+)\s*,\s*err\s*:?=`) - return ReviewRule{ - ID: "CORR002", - Name: "Nil dereference risk", - Category: "correctness", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - _ = returnPat - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if usePat.MatchString(line) { - // Check if next non-blank line uses the variable without nil check. - if i+1 < len(lines) { - nextLine := strings.TrimSpace(lines[i+1]) - varMatch := usePat.FindStringSubmatch(line) - if len(varMatch) > 1 && !strings.Contains(nextLine, "err") && !strings.Contains(nextLine, "nil") && strings.Contains(nextLine, varMatch[1]+".") { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 2, - Severity: "warning", - Category: "correctness", - Message: "Potential nil dereference — value used without checking error", - Suggestion: "Check err != nil before using " + varMatch[1], - RuleID: "CORR002", - }) - } - } - } - } - return comments - }, - } -} - -func ruleUnclosedResources() ReviewRule { - openPat := regexp.MustCompile(`(os\.Open|sql\.Open|net\.Dial|http\.Get)\s*\(`) - deferPat := regexp.MustCompile(`defer\s+\w+\.Close\(\)`) - return ReviewRule{ - ID: "CORR003", - Name: "Unclosed resource", - Category: "correctness", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if openPat.MatchString(line) { - // Look for defer close within next 5 lines. - found := false - for j := i + 1; j < i+6 && j < len(lines); j++ { - if deferPat.MatchString(lines[j]) || strings.Contains(lines[j], ".Close()") { - found = true - break - } - } - if !found { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "warning", - Category: "correctness", - Message: "Resource opened without visible defer Close()", - Suggestion: "Add defer resource.Close() after error check", - RuleID: "CORR003", - }) - } - } - } - return comments - }, - } -} - -func ruleRaceCondition() ReviewRule { - goPat := regexp.MustCompile(`^\s*go\s+\w+`) - sharedPat := regexp.MustCompile(`(shared|global|counter|state)\w*\s*[\+\-]?=`) - return ReviewRule{ - ID: "CORR004", - Name: "Potential race condition", - Category: "correctness", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - hasGoroutine := false - for _, line := range lines { - if goPat.MatchString(line) { - hasGoroutine = true - break - } - } - if !hasGoroutine { - return nil - } - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if sharedPat.MatchString(line) { - // Check if there's a mutex lock nearby. - hasMutex := false - for j := maxInt(0, i-5); j < i; j++ { - if strings.Contains(lines[j], "Lock()") || strings.Contains(lines[j], "RLock()") { - hasMutex = true - break - } - } - if !hasMutex { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "warning", - Category: "correctness", - Message: "Potential race condition — shared state modified without visible lock", - Suggestion: "Protect with sync.Mutex or use atomic operations", - RuleID: "CORR004", - }) - } - } - } - return comments - }, - } -} - -// --- Style rules --- - -func ruleExportedWithoutDocs() ReviewRule { - exportedFunc := regexp.MustCompile(`^func\s+([A-Z]\w*)`) - exportedMethod := regexp.MustCompile(`^func\s+\(\w+\s+\*?\w+\)\s+([A-Z]\w*)`) - exportedType := regexp.MustCompile(`^type\s+([A-Z]\w*)`) - commentPat := regexp.MustCompile(`^//\s*`) - return ReviewRule{ - ID: "STY001", - Name: "Exported symbol without documentation", - Category: "style", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - // Skip test files. - if strings.HasSuffix(file, "_test.go") { - return nil - } - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - var name string - if m := exportedMethod.FindStringSubmatch(line); len(m) > 1 { - name = m[1] - } else if m := exportedFunc.FindStringSubmatch(line); len(m) > 1 { - name = m[1] - } else if m := exportedType.FindStringSubmatch(line); len(m) > 1 { - name = m[1] - } - if name == "" { - continue - } - // Check previous line for comment. - hasDoc := false - if i > 0 && commentPat.MatchString(strings.TrimSpace(lines[i-1])) { - hasDoc = true - } - if !hasDoc { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "info", - Category: "style", - Message: "Exported function missing documentation", - Suggestion: fmt.Sprintf("Add godoc comment: // %s ...", name), - RuleID: "STY001", - }) - } - } - return comments - }, - } -} - -func ruleInconsistentNaming() ReviewRule { - snakePat := regexp.MustCompile(`\b[a-z]+_[a-z]+\b`) - varDecl := regexp.MustCompile(`(var|:=)\s+`) - return ReviewRule{ - ID: "STY002", - Name: "Inconsistent naming convention", - Category: "style", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if varDecl.MatchString(line) && snakePat.MatchString(line) { - // Ignore struct tags and strings. - trimmed := removeStrings(line) - if snakePat.MatchString(trimmed) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "info", - Category: "style", - Message: "Snake_case identifier in Go code (use camelCase)", - Suggestion: "Rename to camelCase per Go conventions", - RuleID: "STY002", - }) - } - } - } - return comments - }, - } -} - -func ruleMagicNumbers() ReviewRule { - numPat := regexp.MustCompile(`[^0-9\.]([2-9]\d{2,}|[1-9]\d{3,})([^0-9\.]|$)`) - ignorePat := regexp.MustCompile(`(const|http\.|port|timeout|test|spec|0x|version|v\d)`) - return ReviewRule{ - ID: "STY003", - Name: "Magic number", - Category: "style", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if strings.HasPrefix(strings.TrimSpace(line), "//") { - continue - } - if ignorePat.MatchString(line) { - continue - } - if numPat.MatchString(line) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "info", - Category: "style", - Message: "Magic number — consider extracting as a named constant", - Suggestion: "Define a const with a descriptive name", - RuleID: "STY003", - }) - } - } - return comments - }, - } -} - -// --- Testing rules --- - -func ruleTestWithoutAssertion() ReviewRule { - testFunc := regexp.MustCompile(`^func\s+Test\w+\(`) - assertion := regexp.MustCompile(`(assert\.|require\.|t\.(Error|Fatal|Fail|Check|Log)|if .* != |expect\(|should)`) - return ReviewRule{ - ID: "TEST001", - Name: "Test without assertion", - Category: "testing", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - if !strings.HasSuffix(file, "_test.go") { - return nil - } - var comments []ReviewComment - inTest := false - testStart := 0 - braceDepth := 0 - hasAssertion := false - for i, line := range lines { - if testFunc.MatchString(line) { - if inTest && !hasAssertion && isChangedLine(testStart, diff) { - comments = append(comments, ReviewComment{ - File: file, - Line: testStart, - Severity: "warning", - Category: "testing", - Message: "Test function without any assertion", - Suggestion: "Add assertions to validate expected behavior", - RuleID: "TEST001", - }) - } - inTest = true - testStart = i + 1 - braceDepth = 0 - hasAssertion = false - } - if inTest { - braceDepth += strings.Count(line, "{") - strings.Count(line, "}") - if assertion.MatchString(line) { - hasAssertion = true - } - if braceDepth <= 0 && i > testStart { - if !hasAssertion && isChangedLine(testStart, diff) { - comments = append(comments, ReviewComment{ - File: file, - Line: testStart, - Severity: "warning", - Category: "testing", - Message: "Test function without any assertion", - Suggestion: "Add assertions to validate expected behavior", - RuleID: "TEST001", - }) - } - inTest = false - } - } - } - return comments - }, - } -} - -func ruleSkippedTests() ReviewRule { - skipPat := regexp.MustCompile(`t\.Skip\(`) - return ReviewRule{ - ID: "TEST002", - Name: "Skipped test", - Category: "testing", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - if !strings.HasSuffix(file, "_test.go") { - return nil - } - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if skipPat.MatchString(line) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "info", - Category: "testing", - Message: "Test explicitly skipped", - Suggestion: "Remove t.Skip() or add a TODO with timeline to re-enable", - RuleID: "TEST002", - }) - } - } - return comments - }, - } -} - -func ruleTestFileWithoutTests() ReviewRule { - testFunc := regexp.MustCompile(`^func\s+Test\w+\(`) - return ReviewRule{ - ID: "TEST003", - Name: "Test file without test functions", - Category: "testing", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - if !strings.HasSuffix(file, "_test.go") { - return nil - } - hasTest := false - for _, line := range lines { - if testFunc.MatchString(line) { - hasTest = true - break - } - } - if !hasTest && len(diff) > 0 { - return []ReviewComment{{ - File: file, - Line: 1, - Severity: "warning", - Category: "testing", - Message: "Test file does not contain any test functions", - Suggestion: "Add test functions or remove the _test.go suffix", - RuleID: "TEST003", - }} - } - return nil - }, - } -} - -// --- Additional rules --- - -func ruleHardcodedIP() ReviewRule { - ipPat := regexp.MustCompile(`\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b`) - ignorePat := regexp.MustCompile(`(127\.0\.0\.1|0\.0\.0\.0|localhost|test|example|spec)`) - return ReviewRule{ - ID: "SEC005", - Name: "Hardcoded IP address", - Category: "security", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if strings.HasPrefix(strings.TrimSpace(line), "//") { - continue - } - if ipPat.MatchString(line) && !ignorePat.MatchString(line) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "warning", - Category: "security", - Message: "Hardcoded IP address", - Suggestion: "Use configuration or environment variable for IP addresses", - RuleID: "SEC005", - }) - } - } - return comments - }, - } -} - -func ruleTODOsInCode() ReviewRule { - todoPat := regexp.MustCompile(`(?i)(TODO|FIXME|HACK|XXX|TEMP)\b`) - return ReviewRule{ - ID: "STY004", - Name: "TODO/FIXME comment", - Category: "style", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if todoPat.MatchString(line) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "info", - Category: "style", - Message: "TODO/FIXME comment — track in issue tracker", - Suggestion: "Create an issue and reference its ID in the comment", - RuleID: "STY004", - }) - } - } - return comments - }, - } -} - -func ruleEmptyErrorHandler() ReviewRule { - catchPat := regexp.MustCompile(`if\s+err\s*!=\s*nil\s*\{`) - return ReviewRule{ - ID: "CORR005", - Name: "Empty error handler", - Category: "correctness", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - if catchPat.MatchString(line) { - // Check if the next line is just a closing brace. - if i+1 < len(lines) { - next := strings.TrimSpace(lines[i+1]) - if next == "}" || next == "// ignore" { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "warning", - Category: "correctness", - Message: "Empty error handler — error is silently swallowed", - Suggestion: "Log the error or return it to the caller", - RuleID: "CORR005", - }) - } - } - } - } - return comments - }, - } -} - -func ruleDeferInLoop() ReviewRule { - loopPat := regexp.MustCompile(`^\s*for\s`) - deferPat := regexp.MustCompile(`^\s*defer\s`) - return ReviewRule{ - ID: "CORR006", - Name: "Defer inside loop", - Category: "correctness", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - inLoop := false - loopStart := 0 - braceDepth := 0 - for i, line := range lines { - if loopPat.MatchString(line) { - inLoop = true - loopStart = i - braceDepth = 0 - } - if inLoop { - braceDepth += strings.Count(line, "{") - strings.Count(line, "}") - if braceDepth <= 0 && i > loopStart { - inLoop = false - } - } - if inLoop && deferPat.MatchString(line) && isChangedLine(i+1, diff) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "warning", - Category: "correctness", - Message: "defer inside loop — deferred calls won't execute until function returns", - Suggestion: "Move resource cleanup into the loop body or extract to a function", - RuleID: "CORR006", - }) - } - } - return comments - }, - } -} - -func ruleUnusedParameter() ReviewRule { - funcPat := regexp.MustCompile(`^func\s+(?:\(\w+\s+\*?\w+\)\s+)?\w+\(([^)]+)\)`) - return ReviewRule{ - ID: "STY005", - Name: "Potentially unused parameter", - Category: "style", - Language: "go", - Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { - var comments []ReviewComment - for i, line := range lines { - if !isChangedLine(i+1, diff) { - continue - } - m := funcPat.FindStringSubmatch(line) - if len(m) < 2 { - continue - } - params := parseParams(m[1]) - // Scan function body for parameter usage. - braceDepth := 0 - bodyStart := i - for j := i; j < len(lines); j++ { - braceDepth += strings.Count(lines[j], "{") - strings.Count(lines[j], "}") - if braceDepth > 0 { - bodyStart = j + 1 - break - } - } - bodyEnd := bodyStart - bodyDepth := 1 - for j := bodyStart; j < len(lines); j++ { - bodyDepth += strings.Count(lines[j], "{") - strings.Count(lines[j], "}") - if bodyDepth <= 0 { - bodyEnd = j - break - } - } - body := strings.Join(lines[bodyStart:minInt(bodyEnd, len(lines))], "\n") - for _, p := range params { - if p == "_" || p == "" { - continue - } - if !strings.Contains(body, p) { - comments = append(comments, ReviewComment{ - File: file, - Line: i + 1, - Severity: "info", - Category: "style", - Message: fmt.Sprintf("Parameter '%s' appears unused in function body", p), - Suggestion: "Remove unused parameter or prefix with _", - RuleID: "STY005", - }) - break // report once per function - } - } - } - return comments - }, - } -} - // ---------- utility ---------- func parseParams(paramStr string) []string { diff --git a/internal/engine/review/review_bot_rules.go b/internal/engine/review/review_bot_rules.go new file mode 100644 index 00000000..8bdf6b4f --- /dev/null +++ b/internal/engine/review/review_bot_rules.go @@ -0,0 +1,928 @@ +package review + +import ( + "fmt" + "regexp" + "strings" + + "github.com/GrayCodeAI/hawk/internal/engine/diff" +) + +// This file holds the built-in review rule set used by ReviewBot. The engine, +// formatting, diff parsing, and shared helpers live in review_bot.go. + +// ---------- built-in rules ---------- + +func builtinReviewRules() []ReviewRule { + return []ReviewRule{ + ruleHardcodedSecrets(), + ruleSQLInjection(), + ruleCommandInjection(), + ruleXSS(), + ruleNPlusOneQuery(), + ruleUnboundedAllocation(), + ruleStringConcatInLoop(), + ruleErrorIgnored(), + ruleNilDereferenceRisk(), + ruleUnclosedResources(), + ruleRaceCondition(), + ruleExportedWithoutDocs(), + ruleInconsistentNaming(), + ruleMagicNumbers(), + ruleTestWithoutAssertion(), + ruleSkippedTests(), + ruleTestFileWithoutTests(), + ruleHardcodedIP(), + ruleTODOsInCode(), + ruleEmptyErrorHandler(), + ruleDeferInLoop(), + ruleUnusedParameter(), + } +} + +// --- Security rules --- + +func ruleHardcodedSecrets() ReviewRule { + patterns := []*regexp.Regexp{ + regexp.MustCompile(`(?i)(password|secret|api_key|apikey|token|private_key)\s*[:=]\s*["']` + "[^\"']{8,}" + `["']`), + regexp.MustCompile(`(?i)(aws_access_key_id|aws_secret_access_key)\s*[:=]\s*["']`), + regexp.MustCompile(`(?i)-----BEGIN (RSA |EC )?PRIVATE KEY-----`), + regexp.MustCompile(`ghp_[A-Za-z0-9]{36}`), + regexp.MustCompile(`sk-[A-Za-z0-9]{20,}`), + } + return ReviewRule{ + ID: "SEC001", + Name: "Hardcoded secret detected", + Category: "security", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + for _, pat := range patterns { + if pat.MatchString(line) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "error", + Category: "security", + Message: "Hardcoded secret detected", + Suggestion: "Use environment variable instead", + RuleID: "SEC001", + }) + break + } + } + } + return comments + }, + } +} + +func ruleSQLInjection() ReviewRule { + patterns := []*regexp.Regexp{ + regexp.MustCompile(`(?i)(fmt\.Sprintf|".*\+.*")\s*.*?(SELECT|INSERT|UPDATE|DELETE|DROP)\s`), + regexp.MustCompile(`(?i)query\s*[:=].*\+\s*\w`), + regexp.MustCompile(`(?i)(Exec|Query|QueryRow)\s*\(\s*fmt\.Sprintf`), + regexp.MustCompile(`(?i)(Exec|Query|QueryRow)\s*\(\s*"[^"]*"\s*\+`), + } + return ReviewRule{ + ID: "SEC002", + Name: "Potential SQL injection", + Category: "security", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + for _, pat := range patterns { + if pat.MatchString(line) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "error", + Category: "security", + Message: "Potential SQL injection via string concatenation", + Suggestion: "Use parameterized queries with placeholder arguments", + RuleID: "SEC002", + }) + break + } + } + } + return comments + }, + } +} + +func ruleCommandInjection() ReviewRule { + pattern := regexp.MustCompile(`(?i)(exec\.Command|os\.system|subprocess\.(call|run|Popen))\s*\(.*\+`) + patternFmt := regexp.MustCompile(`(?i)(exec\.Command|os\.system)\s*\(\s*fmt\.Sprintf`) + return ReviewRule{ + ID: "SEC003", + Name: "Potential command injection", + Category: "security", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if pattern.MatchString(line) || patternFmt.MatchString(line) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "error", + Category: "security", + Message: "Potential command injection via string concatenation", + Suggestion: "Sanitize inputs or use argument lists instead of shell strings", + RuleID: "SEC003", + }) + } + } + return comments + }, + } +} + +func ruleXSS() ReviewRule { + patterns := []*regexp.Regexp{ + regexp.MustCompile(`(?i)innerHTML\s*=`), + regexp.MustCompile(`(?i)document\.write\s*\(`), + regexp.MustCompile(`(?i)fmt\.Fprintf\s*\(\s*w\s*,.*\+`), + regexp.MustCompile(`(?i)template\.HTML\(`), + } + return ReviewRule{ + ID: "SEC004", + Name: "Potential XSS vulnerability", + Category: "security", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + for _, pat := range patterns { + if pat.MatchString(line) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "error", + Category: "security", + Message: "Potential cross-site scripting (XSS) vulnerability", + Suggestion: "Sanitize or escape user input before rendering", + RuleID: "SEC004", + }) + break + } + } + } + return comments + }, + } +} + +// --- Performance rules --- + +func ruleNPlusOneQuery() ReviewRule { + queryPat := regexp.MustCompile(`(?i)(\.Query|\.QueryRow|\.Exec)\s*\(`) + loopPat := regexp.MustCompile(`^\s*(for|range)\s`) + return ReviewRule{ + ID: "PERF001", + Name: "Potential N+1 query", + Category: "performance", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + inLoop := false + loopStart := 0 + braceDepth := 0 + for i, line := range lines { + if loopPat.MatchString(line) { + inLoop = true + loopStart = i + 1 + braceDepth = 0 + } + if inLoop { + braceDepth += strings.Count(line, "{") - strings.Count(line, "}") + if braceDepth <= 0 && i > loopStart { + inLoop = false + } + } + if inLoop && queryPat.MatchString(line) && isChangedLine(i+1, diff) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "warning", + Category: "performance", + Message: "Database query inside loop (potential N+1 query)", + Suggestion: "Batch queries or use a JOIN to fetch all data at once", + RuleID: "PERF001", + }) + } + } + return comments + }, + } +} + +func ruleUnboundedAllocation() ReviewRule { + pattern := regexp.MustCompile(`make\s*\(\s*\[\]\w+\s*,\s*\w+`) + return ReviewRule{ + ID: "PERF002", + Name: "Potentially unbounded allocation", + Category: "performance", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if pattern.MatchString(line) && !strings.Contains(line, "cap") { + // Check if the size variable might be user-controlled. + if strings.Contains(line, "len(") || strings.Contains(line, "req.") || strings.Contains(line, "input") { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "warning", + Category: "performance", + Message: "Potentially unbounded allocation based on external input", + Suggestion: "Add a maximum cap to prevent OOM from large inputs", + RuleID: "PERF002", + }) + } + } + } + return comments + }, + } +} + +func ruleStringConcatInLoop() ReviewRule { + concatPat := regexp.MustCompile(`\w+\s*\+=\s*("|\w)`) + loopPat := regexp.MustCompile(`^\s*(for|range)\s`) + return ReviewRule{ + ID: "PERF003", + Name: "String concatenation in loop", + Category: "performance", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + inLoop := false + loopStart := 0 + braceDepth := 0 + for i, line := range lines { + if loopPat.MatchString(line) { + inLoop = true + loopStart = i + 1 + braceDepth = 0 + } + if inLoop { + braceDepth += strings.Count(line, "{") - strings.Count(line, "}") + if braceDepth <= 0 && i > loopStart { + inLoop = false + } + } + if inLoop && concatPat.MatchString(line) && isChangedLine(i+1, diff) { + if strings.Contains(line, "string") || strings.Contains(line, `"`) || strings.Contains(line, "str") || strings.Contains(line, "result") || strings.Contains(line, "output") { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "warning", + Category: "performance", + Message: "String concatenation inside loop", + Suggestion: "Use strings.Builder for better performance", + RuleID: "PERF003", + }) + } + } + } + return comments + }, + } +} + +// --- Correctness rules --- + +func ruleErrorIgnored() ReviewRule { + patterns := []*regexp.Regexp{ + regexp.MustCompile(`\w+,\s*_\s*:?=\s*\w+.*\(`), + regexp.MustCompile(`^\s*\w+\.\w+\(.*\)\s*$`), + } + errFuncPat := regexp.MustCompile(`(?i)(write|close|flush|send|remove|delete|create)`) + return ReviewRule{ + ID: "CORR001", + Name: "Error return value ignored", + Category: "correctness", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + // Pattern: val, _ := someFunc() + if patterns[0].MatchString(line) && errFuncPat.MatchString(line) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "warning", + Category: "correctness", + Message: "Error return value ignored", + Suggestion: "if err != nil { return err }", + RuleID: "CORR001", + }) + } + } + return comments + }, + } +} + +func ruleNilDereferenceRisk() ReviewRule { + returnPat := regexp.MustCompile(`return\s+nil\s*,`) + usePat := regexp.MustCompile(`(\w+)\s*,\s*err\s*:?=`) + return ReviewRule{ + ID: "CORR002", + Name: "Nil dereference risk", + Category: "correctness", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + _ = returnPat + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if usePat.MatchString(line) { + // Check if next non-blank line uses the variable without nil check. + if i+1 < len(lines) { + nextLine := strings.TrimSpace(lines[i+1]) + varMatch := usePat.FindStringSubmatch(line) + if len(varMatch) > 1 && !strings.Contains(nextLine, "err") && !strings.Contains(nextLine, "nil") && strings.Contains(nextLine, varMatch[1]+".") { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 2, + Severity: "warning", + Category: "correctness", + Message: "Potential nil dereference — value used without checking error", + Suggestion: "Check err != nil before using " + varMatch[1], + RuleID: "CORR002", + }) + } + } + } + } + return comments + }, + } +} + +func ruleUnclosedResources() ReviewRule { + openPat := regexp.MustCompile(`(os\.Open|sql\.Open|net\.Dial|http\.Get)\s*\(`) + deferPat := regexp.MustCompile(`defer\s+\w+\.Close\(\)`) + return ReviewRule{ + ID: "CORR003", + Name: "Unclosed resource", + Category: "correctness", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if openPat.MatchString(line) { + // Look for defer close within next 5 lines. + found := false + for j := i + 1; j < i+6 && j < len(lines); j++ { + if deferPat.MatchString(lines[j]) || strings.Contains(lines[j], ".Close()") { + found = true + break + } + } + if !found { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "warning", + Category: "correctness", + Message: "Resource opened without visible defer Close()", + Suggestion: "Add defer resource.Close() after error check", + RuleID: "CORR003", + }) + } + } + } + return comments + }, + } +} + +func ruleRaceCondition() ReviewRule { + goPat := regexp.MustCompile(`^\s*go\s+\w+`) + sharedPat := regexp.MustCompile(`(shared|global|counter|state)\w*\s*[\+\-]?=`) + return ReviewRule{ + ID: "CORR004", + Name: "Potential race condition", + Category: "correctness", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + hasGoroutine := false + for _, line := range lines { + if goPat.MatchString(line) { + hasGoroutine = true + break + } + } + if !hasGoroutine { + return nil + } + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if sharedPat.MatchString(line) { + // Check if there's a mutex lock nearby. + hasMutex := false + for j := maxInt(0, i-5); j < i; j++ { + if strings.Contains(lines[j], "Lock()") || strings.Contains(lines[j], "RLock()") { + hasMutex = true + break + } + } + if !hasMutex { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "warning", + Category: "correctness", + Message: "Potential race condition — shared state modified without visible lock", + Suggestion: "Protect with sync.Mutex or use atomic operations", + RuleID: "CORR004", + }) + } + } + } + return comments + }, + } +} + +// --- Style rules --- + +func ruleExportedWithoutDocs() ReviewRule { + exportedFunc := regexp.MustCompile(`^func\s+([A-Z]\w*)`) + exportedMethod := regexp.MustCompile(`^func\s+\(\w+\s+\*?\w+\)\s+([A-Z]\w*)`) + exportedType := regexp.MustCompile(`^type\s+([A-Z]\w*)`) + commentPat := regexp.MustCompile(`^//\s*`) + return ReviewRule{ + ID: "STY001", + Name: "Exported symbol without documentation", + Category: "style", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + // Skip test files. + if strings.HasSuffix(file, "_test.go") { + return nil + } + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + var name string + if m := exportedMethod.FindStringSubmatch(line); len(m) > 1 { + name = m[1] + } else if m := exportedFunc.FindStringSubmatch(line); len(m) > 1 { + name = m[1] + } else if m := exportedType.FindStringSubmatch(line); len(m) > 1 { + name = m[1] + } + if name == "" { + continue + } + // Check previous line for comment. + hasDoc := false + if i > 0 && commentPat.MatchString(strings.TrimSpace(lines[i-1])) { + hasDoc = true + } + if !hasDoc { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "info", + Category: "style", + Message: "Exported function missing documentation", + Suggestion: fmt.Sprintf("Add godoc comment: // %s ...", name), + RuleID: "STY001", + }) + } + } + return comments + }, + } +} + +func ruleInconsistentNaming() ReviewRule { + snakePat := regexp.MustCompile(`\b[a-z]+_[a-z]+\b`) + varDecl := regexp.MustCompile(`(var|:=)\s+`) + return ReviewRule{ + ID: "STY002", + Name: "Inconsistent naming convention", + Category: "style", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if varDecl.MatchString(line) && snakePat.MatchString(line) { + // Ignore struct tags and strings. + trimmed := removeStrings(line) + if snakePat.MatchString(trimmed) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "info", + Category: "style", + Message: "Snake_case identifier in Go code (use camelCase)", + Suggestion: "Rename to camelCase per Go conventions", + RuleID: "STY002", + }) + } + } + } + return comments + }, + } +} + +func ruleMagicNumbers() ReviewRule { + numPat := regexp.MustCompile(`[^0-9\.]([2-9]\d{2,}|[1-9]\d{3,})([^0-9\.]|$)`) + ignorePat := regexp.MustCompile(`(const|http\.|port|timeout|test|spec|0x|version|v\d)`) + return ReviewRule{ + ID: "STY003", + Name: "Magic number", + Category: "style", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if strings.HasPrefix(strings.TrimSpace(line), "//") { + continue + } + if ignorePat.MatchString(line) { + continue + } + if numPat.MatchString(line) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "info", + Category: "style", + Message: "Magic number — consider extracting as a named constant", + Suggestion: "Define a const with a descriptive name", + RuleID: "STY003", + }) + } + } + return comments + }, + } +} + +// --- Testing rules --- + +func ruleTestWithoutAssertion() ReviewRule { + testFunc := regexp.MustCompile(`^func\s+Test\w+\(`) + assertion := regexp.MustCompile(`(assert\.|require\.|t\.(Error|Fatal|Fail|Check|Log)|if .* != |expect\(|should)`) + return ReviewRule{ + ID: "TEST001", + Name: "Test without assertion", + Category: "testing", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + if !strings.HasSuffix(file, "_test.go") { + return nil + } + var comments []ReviewComment + inTest := false + testStart := 0 + braceDepth := 0 + hasAssertion := false + for i, line := range lines { + if testFunc.MatchString(line) { + if inTest && !hasAssertion && isChangedLine(testStart, diff) { + comments = append(comments, ReviewComment{ + File: file, + Line: testStart, + Severity: "warning", + Category: "testing", + Message: "Test function without any assertion", + Suggestion: "Add assertions to validate expected behavior", + RuleID: "TEST001", + }) + } + inTest = true + testStart = i + 1 + braceDepth = 0 + hasAssertion = false + } + if inTest { + braceDepth += strings.Count(line, "{") - strings.Count(line, "}") + if assertion.MatchString(line) { + hasAssertion = true + } + if braceDepth <= 0 && i > testStart { + if !hasAssertion && isChangedLine(testStart, diff) { + comments = append(comments, ReviewComment{ + File: file, + Line: testStart, + Severity: "warning", + Category: "testing", + Message: "Test function without any assertion", + Suggestion: "Add assertions to validate expected behavior", + RuleID: "TEST001", + }) + } + inTest = false + } + } + } + return comments + }, + } +} + +func ruleSkippedTests() ReviewRule { + skipPat := regexp.MustCompile(`t\.Skip\(`) + return ReviewRule{ + ID: "TEST002", + Name: "Skipped test", + Category: "testing", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + if !strings.HasSuffix(file, "_test.go") { + return nil + } + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if skipPat.MatchString(line) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "info", + Category: "testing", + Message: "Test explicitly skipped", + Suggestion: "Remove t.Skip() or add a TODO with timeline to re-enable", + RuleID: "TEST002", + }) + } + } + return comments + }, + } +} + +func ruleTestFileWithoutTests() ReviewRule { + testFunc := regexp.MustCompile(`^func\s+Test\w+\(`) + return ReviewRule{ + ID: "TEST003", + Name: "Test file without test functions", + Category: "testing", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + if !strings.HasSuffix(file, "_test.go") { + return nil + } + hasTest := false + for _, line := range lines { + if testFunc.MatchString(line) { + hasTest = true + break + } + } + if !hasTest && len(diff) > 0 { + return []ReviewComment{{ + File: file, + Line: 1, + Severity: "warning", + Category: "testing", + Message: "Test file does not contain any test functions", + Suggestion: "Add test functions or remove the _test.go suffix", + RuleID: "TEST003", + }} + } + return nil + }, + } +} + +// --- Additional rules --- + +func ruleHardcodedIP() ReviewRule { + ipPat := regexp.MustCompile(`\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b`) + ignorePat := regexp.MustCompile(`(127\.0\.0\.1|0\.0\.0\.0|localhost|test|example|spec)`) + return ReviewRule{ + ID: "SEC005", + Name: "Hardcoded IP address", + Category: "security", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if strings.HasPrefix(strings.TrimSpace(line), "//") { + continue + } + if ipPat.MatchString(line) && !ignorePat.MatchString(line) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "warning", + Category: "security", + Message: "Hardcoded IP address", + Suggestion: "Use configuration or environment variable for IP addresses", + RuleID: "SEC005", + }) + } + } + return comments + }, + } +} + +func ruleTODOsInCode() ReviewRule { + todoPat := regexp.MustCompile(`(?i)(TODO|FIXME|HACK|XXX|TEMP)\b`) + return ReviewRule{ + ID: "STY004", + Name: "TODO/FIXME comment", + Category: "style", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if todoPat.MatchString(line) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "info", + Category: "style", + Message: "TODO/FIXME comment — track in issue tracker", + Suggestion: "Create an issue and reference its ID in the comment", + RuleID: "STY004", + }) + } + } + return comments + }, + } +} + +func ruleEmptyErrorHandler() ReviewRule { + catchPat := regexp.MustCompile(`if\s+err\s*!=\s*nil\s*\{`) + return ReviewRule{ + ID: "CORR005", + Name: "Empty error handler", + Category: "correctness", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + if catchPat.MatchString(line) { + // Check if the next line is just a closing brace. + if i+1 < len(lines) { + next := strings.TrimSpace(lines[i+1]) + if next == "}" || next == "// ignore" { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "warning", + Category: "correctness", + Message: "Empty error handler — error is silently swallowed", + Suggestion: "Log the error or return it to the caller", + RuleID: "CORR005", + }) + } + } + } + } + return comments + }, + } +} + +func ruleDeferInLoop() ReviewRule { + loopPat := regexp.MustCompile(`^\s*for\s`) + deferPat := regexp.MustCompile(`^\s*defer\s`) + return ReviewRule{ + ID: "CORR006", + Name: "Defer inside loop", + Category: "correctness", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + inLoop := false + loopStart := 0 + braceDepth := 0 + for i, line := range lines { + if loopPat.MatchString(line) { + inLoop = true + loopStart = i + braceDepth = 0 + } + if inLoop { + braceDepth += strings.Count(line, "{") - strings.Count(line, "}") + if braceDepth <= 0 && i > loopStart { + inLoop = false + } + } + if inLoop && deferPat.MatchString(line) && isChangedLine(i+1, diff) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "warning", + Category: "correctness", + Message: "defer inside loop — deferred calls won't execute until function returns", + Suggestion: "Move resource cleanup into the loop body or extract to a function", + RuleID: "CORR006", + }) + } + } + return comments + }, + } +} + +func ruleUnusedParameter() ReviewRule { + funcPat := regexp.MustCompile(`^func\s+(?:\(\w+\s+\*?\w+\)\s+)?\w+\(([^)]+)\)`) + return ReviewRule{ + ID: "STY005", + Name: "Potentially unused parameter", + Category: "style", + Language: "go", + Check: func(file string, lines []string, diff []diff.DiffLine) []ReviewComment { + var comments []ReviewComment + for i, line := range lines { + if !isChangedLine(i+1, diff) { + continue + } + m := funcPat.FindStringSubmatch(line) + if len(m) < 2 { + continue + } + params := parseParams(m[1]) + // Scan function body for parameter usage. + braceDepth := 0 + bodyStart := i + for j := i; j < len(lines); j++ { + braceDepth += strings.Count(lines[j], "{") - strings.Count(lines[j], "}") + if braceDepth > 0 { + bodyStart = j + 1 + break + } + } + bodyEnd := bodyStart + bodyDepth := 1 + for j := bodyStart; j < len(lines); j++ { + bodyDepth += strings.Count(lines[j], "{") - strings.Count(lines[j], "}") + if bodyDepth <= 0 { + bodyEnd = j + break + } + } + body := strings.Join(lines[bodyStart:minInt(bodyEnd, len(lines))], "\n") + for _, p := range params { + if p == "_" || p == "" { + continue + } + if !strings.Contains(body, p) { + comments = append(comments, ReviewComment{ + File: file, + Line: i + 1, + Severity: "info", + Category: "style", + Message: fmt.Sprintf("Parameter '%s' appears unused in function body", p), + Suggestion: "Remove unused parameter or prefix with _", + RuleID: "STY005", + }) + break // report once per function + } + } + } + return comments + }, + } +} From d3f2afa300b5cd90f956b3973885fb008da81ccd Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 18:47:22 +0530 Subject: [PATCH 03/20] refactor(agents): move built-in persona definitions to persona_builtins.go --- internal/multiagent/agents/persona.go | 490 ----------------- .../multiagent/agents/persona_builtins.go | 496 ++++++++++++++++++ 2 files changed, 496 insertions(+), 490 deletions(-) create mode 100644 internal/multiagent/agents/persona_builtins.go diff --git a/internal/multiagent/agents/persona.go b/internal/multiagent/agents/persona.go index a1d8a94e..a9b28250 100644 --- a/internal/multiagent/agents/persona.go +++ b/internal/multiagent/agents/persona.go @@ -544,496 +544,6 @@ func RenderPersonaFile(persona *Persona) string { return sb.String() } -// BuiltinPersonas returns the set of built-in personas that are auto-created on first run. -func BuiltinPersonas() []*Persona { - now := time.Now() - return []*Persona{ - { - Name: "default", - Description: "Balanced general-purpose coding assistant", - Model: "", - Temperature: 0.5, - MaxTokens: 8192, - Expertise: []string{"backend", "frontend", "testing"}, - CommunicationStyle: "detailed", - SystemPrompt: "You are a skilled software engineer. Help with coding tasks across the full stack. Write clean, idiomatic code with appropriate tests.", - Rules: []string{ - "Follow existing code style and conventions", - "Include error handling", - "Suggest tests for new functionality", - }, - CreatedAt: now, - }, - { - Name: "reviewer", - Description: "Security and correctness focused code reviewer", - Model: "", // inherit session model (was claude-sonnet-4-6) - Temperature: 0.2, - Expertise: []string{"security", "backend", "testing"}, - CommunicationStyle: "concise", - Tools: []string{"Read", "Grep", "Glob", "Bash"}, - SystemPrompt: "You are a thorough code reviewer specializing in security and correctness. Analyze code changes for vulnerabilities, bugs, and improvements.", - Rules: []string{ - "Always check for SQL injection and XSS", - "Flag hardcoded secrets and credentials", - "Verify proper input validation", - "Check error handling completeness", - "Look for race conditions in concurrent code", - }, - CreatedAt: now, - }, - { - Name: "architect", - Description: "High-level system design with minimal code", - Model: "", // inherit session model (was claude-opus-4-6) - Temperature: 0.7, - MaxTokens: 16384, - Expertise: []string{"backend", "devops"}, - CommunicationStyle: "detailed", - ExcludedTools: []string{"Edit", "Write"}, - SystemPrompt: "You are a software architect. Focus on system design, API contracts, and architectural decisions. Prefer diagrams and high-level descriptions over implementation details.", - Rules: []string{ - "Prefer high-level design over implementation", - "Consider scalability and maintainability", - "Document trade-offs explicitly", - "Suggest technology choices with rationale", - }, - CreatedAt: now, - }, - { - Name: "debugger", - Description: "Systematic bug hunter with diagnostic approach", - Model: "", - Temperature: 0.3, - Expertise: []string{"backend", "testing"}, - CommunicationStyle: "detailed", - SystemPrompt: "You are a systematic debugger. Use a scientific approach: observe symptoms, form hypotheses, design experiments, and narrow down root causes methodically.", - Rules: []string{ - "Start by reproducing the bug", - "Form hypotheses before diving into code", - "Use binary search to narrow down causes", - "Check recent changes first", - "Verify the fix does not introduce regressions", - }, - Examples: []PersonaExample{ - { - Input: "The server returns 500 on login", - Output: "Let me systematically diagnose this: 1) Check server logs for the stack trace, 2) Reproduce with curl, 3) Identify the failing handler, 4) Trace the auth flow", - Context: "Web application debugging", - }, - }, - CreatedAt: now, - }, - { - Name: "teacher", - Description: "Explains concepts with tutorial style", - Model: "", - Temperature: 0.6, - MaxTokens: 16384, - Expertise: []string{"frontend", "backend", "testing"}, - CommunicationStyle: "tutorial", - SystemPrompt: "You are a patient teacher and mentor. Explain concepts clearly with examples. Build understanding from fundamentals up. Use analogies to clarify complex ideas.", - Rules: []string{ - "Explain the 'why' before the 'how'", - "Use simple analogies for complex concepts", - "Provide runnable examples", - "Build from simple to complex", - "Anticipate common misconceptions", - }, - CreatedAt: now, - }, - { - Name: "speed", - Description: "Fast and concise, uses cheapest model", - Model: "", // inherit session model (was claude-haiku-3-5) - Temperature: 0.3, - MaxTokens: 4096, - Expertise: []string{"backend", "frontend"}, - CommunicationStyle: "concise", - SystemPrompt: "Be fast and direct. Provide minimal but correct answers. Skip explanations unless asked. Prioritize working code over perfect code.", - Rules: []string{ - "Keep responses under 200 words when possible", - "Skip preamble and get straight to code", - "Only explain if explicitly asked", - "Prefer simple solutions over clever ones", - }, - CreatedAt: now, - }, - { - Name: "planner", - Description: "Decomposes complex tasks into ordered, actionable steps", - Temperature: 0.4, - MaxTokens: 8192, - Expertise: []string{"planning", "backend"}, - CommunicationStyle: "detailed", - ExcludedTools: []string{"Edit", "Write"}, - SystemPrompt: "You are a planning specialist. Break complex problems into clear, sequential, independently-testable steps. Identify dependencies and risks before any code is written.", - Rules: []string{ - "Always identify dependencies between steps", - "Estimate relative effort for each step", - "Flag blockers and risks early", - "Order steps to keep the build green at each stage", - }, - CreatedAt: now, - }, - { - Name: "executor", - Description: "Focused implementer that writes code to spec", - Temperature: 0.3, - MaxTokens: 8192, - Expertise: []string{"backend", "frontend"}, - CommunicationStyle: "concise", - SystemPrompt: "You are a focused implementer. Given a clear spec or plan, write correct, idiomatic code that satisfies the acceptance criteria. Do not expand scope beyond what is specified.", - Rules: []string{ - "Implement exactly what the spec requires, no more", - "Follow existing code style and conventions", - "Run tests after each change", - "Stop and ask if the spec is ambiguous", - }, - CreatedAt: now, - }, - { - Name: "critic", - Description: "Reviews plans and code for flaws before commitment", - Model: "", // inherit session model (was claude-sonnet-4-6) - Temperature: 0.2, - Expertise: []string{"backend", "testing", "security"}, - CommunicationStyle: "concise", - ExcludedTools: []string{"Edit", "Write"}, - SystemPrompt: "You are a constructive critic. Examine plans and code for gaps, risks, edge cases, and over-engineering. Default to skepticism: assume there is a flaw and try to find it.", - Rules: []string{ - "Identify what breaks if each step fails", - "Flag missing edge cases and error paths", - "Call out over-engineering and unnecessary complexity", - "Suggest simpler alternatives when they exist", - }, - CreatedAt: now, - }, - { - Name: "security-reviewer", - Description: "Deep security-focused code reviewer", - Model: "", // inherit session model (was claude-sonnet-4-6) - Temperature: 0.2, - MaxTokens: 8192, - Expertise: []string{"security", "backend"}, - CommunicationStyle: "concise", - Tools: []string{"Read", "Grep", "Glob", "Bash"}, - ExcludedTools: []string{"Edit", "Write"}, - SystemPrompt: "You are a security expert. Focus on the OWASP Top 10, secret handling, authentication and authorization flaws, and input validation. Assume hostile input.", - Rules: []string{ - "Always check for injection (SQL, command, XSS)", - "Flag hardcoded secrets and weak crypto", - "Verify authentication and authorization on every entry point", - "Check for insecure deserialization and SSRF", - }, - CreatedAt: now, - }, - { - Name: "test-engineer", - Description: "Generates tests and analyzes coverage", - Temperature: 0.3, - MaxTokens: 8192, - Expertise: []string{"testing", "backend"}, - CommunicationStyle: "detailed", - SystemPrompt: "You are a test engineer. Write thorough, maintainable tests that cover happy paths, edge cases, and failure modes. Prefer table-driven tests where idiomatic.", - Rules: []string{ - "Cover happy path, edge cases, and error paths", - "Make tests deterministic and isolated", - "Use table-driven tests where the language supports them", - "Test behavior, not implementation details", - }, - CreatedAt: now, - }, - { - Name: "tracer", - Description: "Debugging and trace analysis specialist", - Temperature: 0.3, - Expertise: []string{"tracing", "testing", "backend"}, - CommunicationStyle: "detailed", - SystemPrompt: "You are an observability specialist. Diagnose issues by analyzing logs, traces, and telemetry. Add instrumentation where visibility is missing.", - Rules: []string{ - "Follow the data: logs, traces, metrics before code", - "Reconstruct the timeline of events", - "Add instrumentation to fill visibility gaps", - "Correlate across services using trace IDs", - }, - CreatedAt: now, - }, - { - Name: "verifier", - Description: "Validates implementations against specifications", - Model: "", // inherit session model (was claude-sonnet-4-6) - Temperature: 0.2, - Expertise: []string{"testing", "backend"}, - CommunicationStyle: "concise", - Tools: []string{"Read", "Grep", "Glob", "Bash"}, - SystemPrompt: "You are a verification specialist. Given a spec and an implementation, confirm whether each acceptance criterion is met. Report concrete pass/fail evidence.", - Rules: []string{ - "Check each acceptance criterion individually", - "Provide evidence for every pass or fail verdict", - "Run the actual tests rather than assuming", - "Report partial completion honestly", - }, - CreatedAt: now, - }, - { - // validator is the read-only half of an implement-then-validate - // agent pair: a separate agent reviews the implementation worker's - // output without the ability to change it. Unlike verifier it is - // ReadOnly (no Bash), so its sign-off cannot be tainted by mutating - // the very code it judges. - Name: "validator", - Description: "Read-only validator of an implementation it did not write", - // Model intentionally left empty: a validator should run on whatever - // model the user has configured for the session rather than pinning a - // specific name that may not exist on their provider. (Several - // built-ins pin claude-sonnet-4-6; this one deliberately inherits.) - Model: "", - Temperature: 0.1, - Expertise: []string{"testing", "backend", "security"}, - CommunicationStyle: "concise", - ReadOnly: true, - Tools: []string{"Read", "Grep", "Glob", "LS"}, - ExcludedTools: []string{"Edit", "Write", "Bash"}, - SystemPrompt: "You are a read-only validation agent. You did not write the code under review and you cannot modify it. Inspect the implementation against the stated expected behavior and report, per acceptance criterion, a concrete PASS or FAIL with file:line evidence. Never assume — cite what you actually read.", - Rules: []string{ - "You are read-only: never propose to edit, write, or run shell commands", - "Cite file:line evidence for every PASS or FAIL", - "Judge against the expected behavior, not your own preferences", - "Report partial or unclear completion honestly rather than rounding up", - }, - CreatedAt: now, - }, - { - Name: "integrator", - Description: "Handles merges, integration, and compatibility", - Temperature: 0.3, - Expertise: []string{"integration", "backend", "devops"}, - CommunicationStyle: "detailed", - SystemPrompt: "You are an integration specialist. Resolve merge conflicts, reconcile interfaces, and ensure components work together. Preserve backward compatibility where required.", - Rules: []string{ - "Preserve backward compatibility unless told otherwise", - "Verify interface contracts on both sides", - "Resolve conflicts by understanding intent, not just text", - "Run integration tests after merging", - }, - CreatedAt: now, - }, - { - Name: "documenter", - Description: "Writes documentation and changelogs", - Temperature: 0.5, - MaxTokens: 16384, - Expertise: []string{"documentation"}, - CommunicationStyle: "tutorial", - SystemPrompt: "You are a technical writer. Produce clear, accurate documentation: READMEs, API docs, changelogs, and inline comments. Write for the reader who knows nothing about the change.", - Rules: []string{ - "Lead with what the reader needs to do", - "Include runnable examples", - "Keep changelogs user-facing, not commit-by-commit", - "Document the 'why' for non-obvious decisions", - }, - CreatedAt: now, - }, - { - Name: "devops", - Description: "CI/CD, deployment, and infrastructure specialist", - Temperature: 0.3, - Expertise: []string{"devops", "backend"}, - CommunicationStyle: "detailed", - SystemPrompt: "You are a DevOps engineer. Handle CI/CD pipelines, containerization, deployment, and infrastructure-as-code. Prioritize reproducibility, security, and observability.", - Rules: []string{ - "Make builds reproducible and cacheable", - "Never bake secrets into images or configs", - "Add health checks and observability hooks", - "Prefer declarative infrastructure-as-code", - }, - CreatedAt: now, - }, - { - Name: "performance", - Description: "Performance profiling and optimization specialist", - Temperature: 0.3, - Expertise: []string{"performance", "backend"}, - CommunicationStyle: "detailed", - SystemPrompt: "You are a performance engineer. Profile before optimizing, measure after. Focus on algorithmic complexity, allocations, and hot paths. Avoid premature optimization.", - Rules: []string{ - "Always measure before and after optimizing", - "Identify the actual bottleneck with profiling", - "Prefer algorithmic improvements over micro-optimizations", - "Document the performance impact with numbers", - }, - CreatedAt: now, - }, - { - Name: "refactorer", - Description: "Code cleanup and refactoring specialist", - Temperature: 0.3, - Expertise: []string{"refactoring", "backend", "frontend"}, - CommunicationStyle: "concise", - SystemPrompt: "You are a refactoring specialist. Improve code structure without changing behavior. Make small, atomic, test-backed changes. Reduce duplication and complexity.", - Rules: []string{ - "Never change behavior during a refactor", - "Make small atomic moves, test after each", - "Reduce duplication and cyclomatic complexity", - "Ensure tests pass before and after every step", - }, - CreatedAt: now, - }, - // --- Cavecrew personas (built into GrayCode Hawk) --- - // Three compact, opinionated personas for multi-agent crews. - // Each enforces a strict output format so downstream agents - // can parse the output mechanically. - { - Name: "cavecrew-investigator", - Description: "Compact code investigator with strict 6-word note format", - Temperature: 0.2, - Expertise: []string{"tracing", "backend", "testing"}, - CommunicationStyle: "concise", - Tools: []string{"Read", "Grep", "Glob", "Bash"}, - SystemPrompt: "You are a code investigator. Read code and produce compact notes in the strict format `path:line — symbol — note` where the note is at most 6 words. Every note MUST follow that exact format. No prose, no explanations, no commentary outside the notes. Maximum 20 notes per response. Each note must be on its own line.", - Rules: []string{ - "Every note MUST be `path:line — symbol — note`", - "Notes are at most 6 words after the dash", - "Never use prose, paragraphs, or headings", - "Skip files that don't relate to the question", - "Order notes by importance, most useful first", - }, - Examples: []PersonaExample{ - { - Input: "Where is the cache invalidated?", - Output: "internal/cache/cache.go:42 — Invalidate() — drops all keys\ninternal/api/handlers.go:88 — put() — calls cache.Invalidate", - Context: "Investigating cache invalidation flow", - }, - }, - CreatedAt: now, - }, - { - Name: "cavecrew-builder", - Description: "Focused implementer that refuses multi-file sprawl", - Temperature: 0.3, - Expertise: []string{"backend", "frontend", "testing"}, - CommunicationStyle: "concise", - SystemPrompt: "You are a focused implementer. Given a single-file scope, write correct, idiomatic code. You HARD-REFUSE to edit 3 or more files in one task; if the work spans more than 2 files, split the work into sub-tasks and ask the caller to assign them. Do not expand scope. Do not refactor adjacent code. Do not add dependencies. Do exactly what the spec says, no more.", - Rules: []string{ - "Hard-refuse tasks that touch 3+ files; ask the caller to split", - "Edit at most 2 files per task", - "Do not refactor code outside the spec", - "Do not add new dependencies without explicit approval", - "Run tests after the change; report pass/fail", - "Stop and ask if the spec is ambiguous", - }, - CreatedAt: now, - }, - { - Name: "cavecrew-reviewer", - Description: "Strict severity-coded reviewer with emoji verdicts", - Temperature: 0.2, - Expertise: []string{"security", "backend", "testing", "refactoring"}, - CommunicationStyle: "concise", - Tools: []string{"Read", "Grep", "Glob", "Bash"}, - ExcludedTools: []string{"Edit", "Write"}, - SystemPrompt: "You are a strict reviewer. Examine the proposed change and report findings using ONLY severity emojis at the start of each line. The four severities are: 🔴 blocker (must fix before merge), 🟡 major (should fix soon), 🔵 minor (nit / style), ❓ question (clarify intent). Each finding is on its own line in the format ` path:line — note`. No prose, no headings, no summary paragraphs. Maximum 30 findings.", - Rules: []string{ - "Every finding MUST start with one of 🔴 🟡 🔵 ❓", - "Format: ` path:line — note`", - "Blockers (🔴) only for security, correctness, or data-loss issues", - "Majors (🟡) for performance, maintainability, or test gaps", - "Minors (🔵) for style, naming, or nitpicks", - "Questions (❓) for ambiguous intent; never assume", - "No prose, no summary, no closing remarks", - }, - Examples: []PersonaExample{ - { - Input: "Review the auth refactor in PR #42", - Output: "🔴 internal/auth/jwt.go:18 — signature never expires, no exp claim\n🟡 internal/auth/jwt.go:55 — error message leaks signing key prefix\n🔵 internal/auth/jwt.go:1 — package comment missing\n❓ internal/auth/jwt.go:30 — why HS256 instead of RS256?", - Context: "Reviewing JWT auth refactor", - }, - }, - CreatedAt: now, - }, - } -} - -// CavecrewPersonas returns just the three cavecrew personas -// (investigator, builder, reviewer) built into GrayCode Hawk. -// These are a strict, format-driven subset of the full BuiltinPersonas -// list; callers that want only the cavecrew subset can use this -// function instead of BuiltinPersonas. -func CavecrewPersonas() []*Persona { - now := time.Now() - return []*Persona{ - { - Name: "cavecrew-investigator", - Description: "Compact code investigator with strict 6-word note format", - Temperature: 0.2, - Expertise: []string{"tracing", "backend", "testing"}, - CommunicationStyle: "concise", - Tools: []string{"Read", "Grep", "Glob", "Bash"}, - SystemPrompt: "You are a code investigator. Read code and produce compact notes in the strict format `path:line — symbol — note` where the note is at most 6 words. Every note MUST follow that exact format. No prose, no explanations, no commentary outside the notes. Maximum 20 notes per response. Each note must be on its own line.", - Rules: []string{ - "Every note MUST be `path:line — symbol — note`", - "Notes are at most 6 words after the dash", - "Never use prose, paragraphs, or headings", - "Skip files that don't relate to the question", - "Order notes by importance, most useful first", - }, - Examples: []PersonaExample{ - { - Input: "Where is the cache invalidated?", - Output: "internal/cache/cache.go:42 — Invalidate() — drops all keys\ninternal/api/handlers.go:88 — put() — calls cache.Invalidate", - Context: "Investigating cache invalidation flow", - }, - }, - CreatedAt: now, - }, - { - Name: "cavecrew-builder", - Description: "Focused implementer that refuses multi-file sprawl", - Temperature: 0.3, - Expertise: []string{"backend", "frontend", "testing"}, - CommunicationStyle: "concise", - SystemPrompt: "You are a focused implementer. Given a single-file scope, write correct, idiomatic code. You HARD-REFUSE to edit 3 or more files in one task; if the work spans more than 2 files, split the work into sub-tasks and ask the caller to assign them. Do not expand scope. Do not refactor adjacent code. Do not add dependencies. Do exactly what the spec says, no more.", - Rules: []string{ - "Hard-refuse tasks that touch 3+ files; ask the caller to split", - "Edit at most 2 files per task", - "Do not refactor code outside the spec", - "Do not add new dependencies without explicit approval", - "Run tests after the change; report pass/fail", - "Stop and ask if the spec is ambiguous", - }, - CreatedAt: now, - }, - { - Name: "cavecrew-reviewer", - Description: "Strict severity-coded reviewer with emoji verdicts", - Temperature: 0.2, - Expertise: []string{"security", "backend", "testing", "refactoring"}, - CommunicationStyle: "concise", - Tools: []string{"Read", "Grep", "Glob", "Bash"}, - ExcludedTools: []string{"Edit", "Write"}, - SystemPrompt: "You are a strict reviewer. Examine the proposed change and report findings using ONLY severity emojis at the start of each line. The four severities are: 🔴 blocker (must fix before merge), 🟡 major (should fix soon), 🔵 minor (nit / style), ❓ question (clarify intent). Each finding is on its own line in the format ` path:line — note`. No prose, no headings, no summary paragraphs. Maximum 30 findings.", - Rules: []string{ - "Every finding MUST start with one of 🔴 🟡 🔵 ❓", - "Format: ` path:line — note`", - "Blockers (🔴) only for security, correctness, or data-loss issues", - "Majors (🟡) for performance, maintainability, or test gaps", - "Minors (🔵) for style, naming, or nitpicks", - "Questions (❓) for ambiguous intent; never assume", - "No prose, no summary, no closing remarks", - }, - Examples: []PersonaExample{ - { - Input: "Review the auth refactor in PR #42", - Output: "🔴 internal/auth/jwt.go:18 — signature never expires, no exp claim\n🟡 internal/auth/jwt.go:55 — error message leaks signing key prefix\n🔵 internal/auth/jwt.go:1 — package comment missing\n❓ internal/auth/jwt.go:30 — why HS256 instead of RS256?", - Context: "Reviewing JWT auth refactor", - }, - }, - CreatedAt: now, - }, - } -} - // EnsureBuiltins creates the built-in personas in the directory if they do not exist. func (r *PersonaRegistry) EnsureBuiltins() error { if err := os.MkdirAll(r.Dir, 0o755); err != nil { diff --git a/internal/multiagent/agents/persona_builtins.go b/internal/multiagent/agents/persona_builtins.go new file mode 100644 index 00000000..434df0d5 --- /dev/null +++ b/internal/multiagent/agents/persona_builtins.go @@ -0,0 +1,496 @@ +package agents + +import "time" + +// This file holds the built-in persona definitions (data). The registry, +// markdown parser/renderer, selection logic, and helpers live in persona.go. + +// BuiltinPersonas returns the set of built-in personas that are auto-created on first run. +func BuiltinPersonas() []*Persona { + now := time.Now() + return []*Persona{ + { + Name: "default", + Description: "Balanced general-purpose coding assistant", + Model: "", + Temperature: 0.5, + MaxTokens: 8192, + Expertise: []string{"backend", "frontend", "testing"}, + CommunicationStyle: "detailed", + SystemPrompt: "You are a skilled software engineer. Help with coding tasks across the full stack. Write clean, idiomatic code with appropriate tests.", + Rules: []string{ + "Follow existing code style and conventions", + "Include error handling", + "Suggest tests for new functionality", + }, + CreatedAt: now, + }, + { + Name: "reviewer", + Description: "Security and correctness focused code reviewer", + Model: "", // inherit session model (was claude-sonnet-4-6) + Temperature: 0.2, + Expertise: []string{"security", "backend", "testing"}, + CommunicationStyle: "concise", + Tools: []string{"Read", "Grep", "Glob", "Bash"}, + SystemPrompt: "You are a thorough code reviewer specializing in security and correctness. Analyze code changes for vulnerabilities, bugs, and improvements.", + Rules: []string{ + "Always check for SQL injection and XSS", + "Flag hardcoded secrets and credentials", + "Verify proper input validation", + "Check error handling completeness", + "Look for race conditions in concurrent code", + }, + CreatedAt: now, + }, + { + Name: "architect", + Description: "High-level system design with minimal code", + Model: "", // inherit session model (was claude-opus-4-6) + Temperature: 0.7, + MaxTokens: 16384, + Expertise: []string{"backend", "devops"}, + CommunicationStyle: "detailed", + ExcludedTools: []string{"Edit", "Write"}, + SystemPrompt: "You are a software architect. Focus on system design, API contracts, and architectural decisions. Prefer diagrams and high-level descriptions over implementation details.", + Rules: []string{ + "Prefer high-level design over implementation", + "Consider scalability and maintainability", + "Document trade-offs explicitly", + "Suggest technology choices with rationale", + }, + CreatedAt: now, + }, + { + Name: "debugger", + Description: "Systematic bug hunter with diagnostic approach", + Model: "", + Temperature: 0.3, + Expertise: []string{"backend", "testing"}, + CommunicationStyle: "detailed", + SystemPrompt: "You are a systematic debugger. Use a scientific approach: observe symptoms, form hypotheses, design experiments, and narrow down root causes methodically.", + Rules: []string{ + "Start by reproducing the bug", + "Form hypotheses before diving into code", + "Use binary search to narrow down causes", + "Check recent changes first", + "Verify the fix does not introduce regressions", + }, + Examples: []PersonaExample{ + { + Input: "The server returns 500 on login", + Output: "Let me systematically diagnose this: 1) Check server logs for the stack trace, 2) Reproduce with curl, 3) Identify the failing handler, 4) Trace the auth flow", + Context: "Web application debugging", + }, + }, + CreatedAt: now, + }, + { + Name: "teacher", + Description: "Explains concepts with tutorial style", + Model: "", + Temperature: 0.6, + MaxTokens: 16384, + Expertise: []string{"frontend", "backend", "testing"}, + CommunicationStyle: "tutorial", + SystemPrompt: "You are a patient teacher and mentor. Explain concepts clearly with examples. Build understanding from fundamentals up. Use analogies to clarify complex ideas.", + Rules: []string{ + "Explain the 'why' before the 'how'", + "Use simple analogies for complex concepts", + "Provide runnable examples", + "Build from simple to complex", + "Anticipate common misconceptions", + }, + CreatedAt: now, + }, + { + Name: "speed", + Description: "Fast and concise, uses cheapest model", + Model: "", // inherit session model (was claude-haiku-3-5) + Temperature: 0.3, + MaxTokens: 4096, + Expertise: []string{"backend", "frontend"}, + CommunicationStyle: "concise", + SystemPrompt: "Be fast and direct. Provide minimal but correct answers. Skip explanations unless asked. Prioritize working code over perfect code.", + Rules: []string{ + "Keep responses under 200 words when possible", + "Skip preamble and get straight to code", + "Only explain if explicitly asked", + "Prefer simple solutions over clever ones", + }, + CreatedAt: now, + }, + { + Name: "planner", + Description: "Decomposes complex tasks into ordered, actionable steps", + Temperature: 0.4, + MaxTokens: 8192, + Expertise: []string{"planning", "backend"}, + CommunicationStyle: "detailed", + ExcludedTools: []string{"Edit", "Write"}, + SystemPrompt: "You are a planning specialist. Break complex problems into clear, sequential, independently-testable steps. Identify dependencies and risks before any code is written.", + Rules: []string{ + "Always identify dependencies between steps", + "Estimate relative effort for each step", + "Flag blockers and risks early", + "Order steps to keep the build green at each stage", + }, + CreatedAt: now, + }, + { + Name: "executor", + Description: "Focused implementer that writes code to spec", + Temperature: 0.3, + MaxTokens: 8192, + Expertise: []string{"backend", "frontend"}, + CommunicationStyle: "concise", + SystemPrompt: "You are a focused implementer. Given a clear spec or plan, write correct, idiomatic code that satisfies the acceptance criteria. Do not expand scope beyond what is specified.", + Rules: []string{ + "Implement exactly what the spec requires, no more", + "Follow existing code style and conventions", + "Run tests after each change", + "Stop and ask if the spec is ambiguous", + }, + CreatedAt: now, + }, + { + Name: "critic", + Description: "Reviews plans and code for flaws before commitment", + Model: "", // inherit session model (was claude-sonnet-4-6) + Temperature: 0.2, + Expertise: []string{"backend", "testing", "security"}, + CommunicationStyle: "concise", + ExcludedTools: []string{"Edit", "Write"}, + SystemPrompt: "You are a constructive critic. Examine plans and code for gaps, risks, edge cases, and over-engineering. Default to skepticism: assume there is a flaw and try to find it.", + Rules: []string{ + "Identify what breaks if each step fails", + "Flag missing edge cases and error paths", + "Call out over-engineering and unnecessary complexity", + "Suggest simpler alternatives when they exist", + }, + CreatedAt: now, + }, + { + Name: "security-reviewer", + Description: "Deep security-focused code reviewer", + Model: "", // inherit session model (was claude-sonnet-4-6) + Temperature: 0.2, + MaxTokens: 8192, + Expertise: []string{"security", "backend"}, + CommunicationStyle: "concise", + Tools: []string{"Read", "Grep", "Glob", "Bash"}, + ExcludedTools: []string{"Edit", "Write"}, + SystemPrompt: "You are a security expert. Focus on the OWASP Top 10, secret handling, authentication and authorization flaws, and input validation. Assume hostile input.", + Rules: []string{ + "Always check for injection (SQL, command, XSS)", + "Flag hardcoded secrets and weak crypto", + "Verify authentication and authorization on every entry point", + "Check for insecure deserialization and SSRF", + }, + CreatedAt: now, + }, + { + Name: "test-engineer", + Description: "Generates tests and analyzes coverage", + Temperature: 0.3, + MaxTokens: 8192, + Expertise: []string{"testing", "backend"}, + CommunicationStyle: "detailed", + SystemPrompt: "You are a test engineer. Write thorough, maintainable tests that cover happy paths, edge cases, and failure modes. Prefer table-driven tests where idiomatic.", + Rules: []string{ + "Cover happy path, edge cases, and error paths", + "Make tests deterministic and isolated", + "Use table-driven tests where the language supports them", + "Test behavior, not implementation details", + }, + CreatedAt: now, + }, + { + Name: "tracer", + Description: "Debugging and trace analysis specialist", + Temperature: 0.3, + Expertise: []string{"tracing", "testing", "backend"}, + CommunicationStyle: "detailed", + SystemPrompt: "You are an observability specialist. Diagnose issues by analyzing logs, traces, and telemetry. Add instrumentation where visibility is missing.", + Rules: []string{ + "Follow the data: logs, traces, metrics before code", + "Reconstruct the timeline of events", + "Add instrumentation to fill visibility gaps", + "Correlate across services using trace IDs", + }, + CreatedAt: now, + }, + { + Name: "verifier", + Description: "Validates implementations against specifications", + Model: "", // inherit session model (was claude-sonnet-4-6) + Temperature: 0.2, + Expertise: []string{"testing", "backend"}, + CommunicationStyle: "concise", + Tools: []string{"Read", "Grep", "Glob", "Bash"}, + SystemPrompt: "You are a verification specialist. Given a spec and an implementation, confirm whether each acceptance criterion is met. Report concrete pass/fail evidence.", + Rules: []string{ + "Check each acceptance criterion individually", + "Provide evidence for every pass or fail verdict", + "Run the actual tests rather than assuming", + "Report partial completion honestly", + }, + CreatedAt: now, + }, + { + // validator is the read-only half of an implement-then-validate + // agent pair: a separate agent reviews the implementation worker's + // output without the ability to change it. Unlike verifier it is + // ReadOnly (no Bash), so its sign-off cannot be tainted by mutating + // the very code it judges. + Name: "validator", + Description: "Read-only validator of an implementation it did not write", + // Model intentionally left empty: a validator should run on whatever + // model the user has configured for the session rather than pinning a + // specific name that may not exist on their provider. (Several + // built-ins pin claude-sonnet-4-6; this one deliberately inherits.) + Model: "", + Temperature: 0.1, + Expertise: []string{"testing", "backend", "security"}, + CommunicationStyle: "concise", + ReadOnly: true, + Tools: []string{"Read", "Grep", "Glob", "LS"}, + ExcludedTools: []string{"Edit", "Write", "Bash"}, + SystemPrompt: "You are a read-only validation agent. You did not write the code under review and you cannot modify it. Inspect the implementation against the stated expected behavior and report, per acceptance criterion, a concrete PASS or FAIL with file:line evidence. Never assume — cite what you actually read.", + Rules: []string{ + "You are read-only: never propose to edit, write, or run shell commands", + "Cite file:line evidence for every PASS or FAIL", + "Judge against the expected behavior, not your own preferences", + "Report partial or unclear completion honestly rather than rounding up", + }, + CreatedAt: now, + }, + { + Name: "integrator", + Description: "Handles merges, integration, and compatibility", + Temperature: 0.3, + Expertise: []string{"integration", "backend", "devops"}, + CommunicationStyle: "detailed", + SystemPrompt: "You are an integration specialist. Resolve merge conflicts, reconcile interfaces, and ensure components work together. Preserve backward compatibility where required.", + Rules: []string{ + "Preserve backward compatibility unless told otherwise", + "Verify interface contracts on both sides", + "Resolve conflicts by understanding intent, not just text", + "Run integration tests after merging", + }, + CreatedAt: now, + }, + { + Name: "documenter", + Description: "Writes documentation and changelogs", + Temperature: 0.5, + MaxTokens: 16384, + Expertise: []string{"documentation"}, + CommunicationStyle: "tutorial", + SystemPrompt: "You are a technical writer. Produce clear, accurate documentation: READMEs, API docs, changelogs, and inline comments. Write for the reader who knows nothing about the change.", + Rules: []string{ + "Lead with what the reader needs to do", + "Include runnable examples", + "Keep changelogs user-facing, not commit-by-commit", + "Document the 'why' for non-obvious decisions", + }, + CreatedAt: now, + }, + { + Name: "devops", + Description: "CI/CD, deployment, and infrastructure specialist", + Temperature: 0.3, + Expertise: []string{"devops", "backend"}, + CommunicationStyle: "detailed", + SystemPrompt: "You are a DevOps engineer. Handle CI/CD pipelines, containerization, deployment, and infrastructure-as-code. Prioritize reproducibility, security, and observability.", + Rules: []string{ + "Make builds reproducible and cacheable", + "Never bake secrets into images or configs", + "Add health checks and observability hooks", + "Prefer declarative infrastructure-as-code", + }, + CreatedAt: now, + }, + { + Name: "performance", + Description: "Performance profiling and optimization specialist", + Temperature: 0.3, + Expertise: []string{"performance", "backend"}, + CommunicationStyle: "detailed", + SystemPrompt: "You are a performance engineer. Profile before optimizing, measure after. Focus on algorithmic complexity, allocations, and hot paths. Avoid premature optimization.", + Rules: []string{ + "Always measure before and after optimizing", + "Identify the actual bottleneck with profiling", + "Prefer algorithmic improvements over micro-optimizations", + "Document the performance impact with numbers", + }, + CreatedAt: now, + }, + { + Name: "refactorer", + Description: "Code cleanup and refactoring specialist", + Temperature: 0.3, + Expertise: []string{"refactoring", "backend", "frontend"}, + CommunicationStyle: "concise", + SystemPrompt: "You are a refactoring specialist. Improve code structure without changing behavior. Make small, atomic, test-backed changes. Reduce duplication and complexity.", + Rules: []string{ + "Never change behavior during a refactor", + "Make small atomic moves, test after each", + "Reduce duplication and cyclomatic complexity", + "Ensure tests pass before and after every step", + }, + CreatedAt: now, + }, + // --- Cavecrew personas (built into GrayCode Hawk) --- + // Three compact, opinionated personas for multi-agent crews. + // Each enforces a strict output format so downstream agents + // can parse the output mechanically. + { + Name: "cavecrew-investigator", + Description: "Compact code investigator with strict 6-word note format", + Temperature: 0.2, + Expertise: []string{"tracing", "backend", "testing"}, + CommunicationStyle: "concise", + Tools: []string{"Read", "Grep", "Glob", "Bash"}, + SystemPrompt: "You are a code investigator. Read code and produce compact notes in the strict format `path:line — symbol — note` where the note is at most 6 words. Every note MUST follow that exact format. No prose, no explanations, no commentary outside the notes. Maximum 20 notes per response. Each note must be on its own line.", + Rules: []string{ + "Every note MUST be `path:line — symbol — note`", + "Notes are at most 6 words after the dash", + "Never use prose, paragraphs, or headings", + "Skip files that don't relate to the question", + "Order notes by importance, most useful first", + }, + Examples: []PersonaExample{ + { + Input: "Where is the cache invalidated?", + Output: "internal/cache/cache.go:42 — Invalidate() — drops all keys\ninternal/api/handlers.go:88 — put() — calls cache.Invalidate", + Context: "Investigating cache invalidation flow", + }, + }, + CreatedAt: now, + }, + { + Name: "cavecrew-builder", + Description: "Focused implementer that refuses multi-file sprawl", + Temperature: 0.3, + Expertise: []string{"backend", "frontend", "testing"}, + CommunicationStyle: "concise", + SystemPrompt: "You are a focused implementer. Given a single-file scope, write correct, idiomatic code. You HARD-REFUSE to edit 3 or more files in one task; if the work spans more than 2 files, split the work into sub-tasks and ask the caller to assign them. Do not expand scope. Do not refactor adjacent code. Do not add dependencies. Do exactly what the spec says, no more.", + Rules: []string{ + "Hard-refuse tasks that touch 3+ files; ask the caller to split", + "Edit at most 2 files per task", + "Do not refactor code outside the spec", + "Do not add new dependencies without explicit approval", + "Run tests after the change; report pass/fail", + "Stop and ask if the spec is ambiguous", + }, + CreatedAt: now, + }, + { + Name: "cavecrew-reviewer", + Description: "Strict severity-coded reviewer with emoji verdicts", + Temperature: 0.2, + Expertise: []string{"security", "backend", "testing", "refactoring"}, + CommunicationStyle: "concise", + Tools: []string{"Read", "Grep", "Glob", "Bash"}, + ExcludedTools: []string{"Edit", "Write"}, + SystemPrompt: "You are a strict reviewer. Examine the proposed change and report findings using ONLY severity emojis at the start of each line. The four severities are: 🔴 blocker (must fix before merge), 🟡 major (should fix soon), 🔵 minor (nit / style), ❓ question (clarify intent). Each finding is on its own line in the format ` path:line — note`. No prose, no headings, no summary paragraphs. Maximum 30 findings.", + Rules: []string{ + "Every finding MUST start with one of 🔴 🟡 🔵 ❓", + "Format: ` path:line — note`", + "Blockers (🔴) only for security, correctness, or data-loss issues", + "Majors (🟡) for performance, maintainability, or test gaps", + "Minors (🔵) for style, naming, or nitpicks", + "Questions (❓) for ambiguous intent; never assume", + "No prose, no summary, no closing remarks", + }, + Examples: []PersonaExample{ + { + Input: "Review the auth refactor in PR #42", + Output: "🔴 internal/auth/jwt.go:18 — signature never expires, no exp claim\n🟡 internal/auth/jwt.go:55 — error message leaks signing key prefix\n🔵 internal/auth/jwt.go:1 — package comment missing\n❓ internal/auth/jwt.go:30 — why HS256 instead of RS256?", + Context: "Reviewing JWT auth refactor", + }, + }, + CreatedAt: now, + }, + } +} + +// CavecrewPersonas returns just the three cavecrew personas +// (investigator, builder, reviewer) built into GrayCode Hawk. +// These are a strict, format-driven subset of the full BuiltinPersonas +// list; callers that want only the cavecrew subset can use this +// function instead of BuiltinPersonas. +func CavecrewPersonas() []*Persona { + now := time.Now() + return []*Persona{ + { + Name: "cavecrew-investigator", + Description: "Compact code investigator with strict 6-word note format", + Temperature: 0.2, + Expertise: []string{"tracing", "backend", "testing"}, + CommunicationStyle: "concise", + Tools: []string{"Read", "Grep", "Glob", "Bash"}, + SystemPrompt: "You are a code investigator. Read code and produce compact notes in the strict format `path:line — symbol — note` where the note is at most 6 words. Every note MUST follow that exact format. No prose, no explanations, no commentary outside the notes. Maximum 20 notes per response. Each note must be on its own line.", + Rules: []string{ + "Every note MUST be `path:line — symbol — note`", + "Notes are at most 6 words after the dash", + "Never use prose, paragraphs, or headings", + "Skip files that don't relate to the question", + "Order notes by importance, most useful first", + }, + Examples: []PersonaExample{ + { + Input: "Where is the cache invalidated?", + Output: "internal/cache/cache.go:42 — Invalidate() — drops all keys\ninternal/api/handlers.go:88 — put() — calls cache.Invalidate", + Context: "Investigating cache invalidation flow", + }, + }, + CreatedAt: now, + }, + { + Name: "cavecrew-builder", + Description: "Focused implementer that refuses multi-file sprawl", + Temperature: 0.3, + Expertise: []string{"backend", "frontend", "testing"}, + CommunicationStyle: "concise", + SystemPrompt: "You are a focused implementer. Given a single-file scope, write correct, idiomatic code. You HARD-REFUSE to edit 3 or more files in one task; if the work spans more than 2 files, split the work into sub-tasks and ask the caller to assign them. Do not expand scope. Do not refactor adjacent code. Do not add dependencies. Do exactly what the spec says, no more.", + Rules: []string{ + "Hard-refuse tasks that touch 3+ files; ask the caller to split", + "Edit at most 2 files per task", + "Do not refactor code outside the spec", + "Do not add new dependencies without explicit approval", + "Run tests after the change; report pass/fail", + "Stop and ask if the spec is ambiguous", + }, + CreatedAt: now, + }, + { + Name: "cavecrew-reviewer", + Description: "Strict severity-coded reviewer with emoji verdicts", + Temperature: 0.2, + Expertise: []string{"security", "backend", "testing", "refactoring"}, + CommunicationStyle: "concise", + Tools: []string{"Read", "Grep", "Glob", "Bash"}, + ExcludedTools: []string{"Edit", "Write"}, + SystemPrompt: "You are a strict reviewer. Examine the proposed change and report findings using ONLY severity emojis at the start of each line. The four severities are: 🔴 blocker (must fix before merge), 🟡 major (should fix soon), 🔵 minor (nit / style), ❓ question (clarify intent). Each finding is on its own line in the format ` path:line — note`. No prose, no headings, no summary paragraphs. Maximum 30 findings.", + Rules: []string{ + "Every finding MUST start with one of 🔴 🟡 🔵 ❓", + "Format: ` path:line — note`", + "Blockers (🔴) only for security, correctness, or data-loss issues", + "Majors (🟡) for performance, maintainability, or test gaps", + "Minors (🔵) for style, naming, or nitpicks", + "Questions (❓) for ambiguous intent; never assume", + "No prose, no summary, no closing remarks", + }, + Examples: []PersonaExample{ + { + Input: "Review the auth refactor in PR #42", + Output: "🔴 internal/auth/jwt.go:18 — signature never expires, no exp claim\n🟡 internal/auth/jwt.go:55 — error message leaks signing key prefix\n🔵 internal/auth/jwt.go:1 — package comment missing\n❓ internal/auth/jwt.go:30 — why HS256 instead of RS256?", + Context: "Reviewing JWT auth refactor", + }, + }, + CreatedAt: now, + }, + } +} From 7701e30cb969ca414fa0336e14d7b76d7c81a168 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 18:55:47 +0530 Subject: [PATCH 04/20] refactor(fingerprint): split project.go into project_detect.go and project_conventions.go --- internal/feature/fingerprint/project.go | 1133 ----------------- .../fingerprint/project_conventions.go | 439 +++++++ .../feature/fingerprint/project_detect.go | 718 +++++++++++ 3 files changed, 1157 insertions(+), 1133 deletions(-) create mode 100644 internal/feature/fingerprint/project_conventions.go create mode 100644 internal/feature/fingerprint/project_detect.go diff --git a/internal/feature/fingerprint/project.go b/internal/feature/fingerprint/project.go index 376392a4..479e5b30 100644 --- a/internal/feature/fingerprint/project.go +++ b/internal/feature/fingerprint/project.go @@ -1,16 +1,9 @@ package fingerprint import ( - "bufio" - "context" - "encoding/json" "fmt" - "io/fs" "os" - "os/exec" "path/filepath" - "regexp" - "sort" "strings" ) @@ -116,1132 +109,6 @@ func Scan(projectDir string) (*ProjectFingerprint, error) { return fp, nil } -// detectLanguages walks the project directory, maps extensions to languages, -// calculates percentages, and returns results sorted by file count descending. -func detectLanguages(dir string) []ProjectLangInfo { - counts := make(map[string]int) - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return nil - } - if d.IsDir() { - if skipDirs[d.Name()] { - return filepath.SkipDir - } - return nil - } - if !d.Type().IsRegular() { - return nil - } - - ext := filepath.Ext(path) - if lang, ok := extToLang[ext]; ok { - counts[lang]++ - } - return nil - }) - - if len(counts) == 0 { - return nil - } - - total := 0 - for _, c := range counts { - total += c - } - - langs := make([]ProjectLangInfo, 0, len(counts)) - for name, count := range counts { - pct := float64(count) / float64(total) * 100 - langs = append(langs, ProjectLangInfo{ - Name: name, - FileCount: count, - Percentage: pct, - }) - } - - sort.Slice(langs, func(i, j int) bool { - return langs[i].FileCount > langs[j].FileCount - }) - - return langs -} - -// detectFramework reads key config files to determine the web/app framework. -func detectFramework(dir string, primaryLang string) string { - switch primaryLang { - case "Go": - return detectGoFramework(dir) - case "Python": - return detectPythonFramework(dir) - case "JavaScript", "TypeScript": - return detectJSFramework(dir) - case "Rust": - return detectRustFramework(dir) - } - return "" -} - -// detectGoFramework reads go.mod for known Go web frameworks. -func detectGoFramework(dir string) string { - goModPath := filepath.Join(dir, "go.mod") - data, err := os.ReadFile(goModPath) - if err != nil { - return "" - } - content := string(data) - - frameworks := []struct { - module string - name string - }{ - {"github.com/go-chi/chi", "chi"}, - {"github.com/gin-gonic/gin", "gin"}, - {"github.com/labstack/echo", "echo"}, - {"github.com/gofiber/fiber", "fiber"}, - {"github.com/gorilla/mux", "gorilla"}, - {"github.com/julienschmidt/httprouter", "httprouter"}, - {"github.com/valyala/fasthttp", "fasthttp"}, - } - - for _, fw := range frameworks { - if strings.Contains(content, fw.module) { - return fw.name - } - } - - // Check for net/http usage (fallback — it's in stdlib so not in go.mod). - return "" -} - -// detectPythonFramework reads requirements.txt, Pipfile, or pyproject.toml. -func detectPythonFramework(dir string) string { - files := []string{ - filepath.Join(dir, "requirements.txt"), - filepath.Join(dir, "Pipfile"), - filepath.Join(dir, "pyproject.toml"), - filepath.Join(dir, "setup.py"), - } - - frameworks := []struct { - keyword string - name string - }{ - {"django", "django"}, - {"Django", "django"}, - {"flask", "flask"}, - {"Flask", "flask"}, - {"fastapi", "fastapi"}, - {"FastAPI", "fastapi"}, - {"tornado", "tornado"}, - {"starlette", "starlette"}, - {"sanic", "sanic"}, - } - - for _, f := range files { - data, err := os.ReadFile(f) - if err != nil { - continue - } - content := string(data) - for _, fw := range frameworks { - if strings.Contains(content, fw.keyword) { - return fw.name - } - } - } - - return "" -} - -// detectJSFramework reads package.json for known JS/TS frameworks. -func detectJSFramework(dir string) string { - pkgPath := filepath.Join(dir, "package.json") - data, err := os.ReadFile(pkgPath) - if err != nil { - return "" - } - - var pkg struct { - Dependencies map[string]interface{} `json:"dependencies"` - DevDependencies map[string]interface{} `json:"devDependencies"` - } - if err := json.Unmarshal(data, &pkg); err != nil { - return "" - } - - // Merge deps for lookup. - allDeps := make(map[string]bool) - for k := range pkg.Dependencies { - allDeps[k] = true - } - for k := range pkg.DevDependencies { - allDeps[k] = true - } - - frameworks := []struct { - pkg string - name string - }{ - {"next", "next.js"}, - {"nuxt", "nuxt"}, - {"@angular/core", "angular"}, - {"vue", "vue"}, - {"svelte", "svelte"}, - {"express", "express"}, - {"fastify", "fastify"}, - {"koa", "koa"}, - {"hapi", "hapi"}, - {"react", "react"}, - {"gatsby", "gatsby"}, - {"remix", "remix"}, - } - - for _, fw := range frameworks { - if allDeps[fw.pkg] { - return fw.name - } - } - - return "" -} - -// detectRustFramework reads Cargo.toml for known Rust web frameworks. -func detectRustFramework(dir string) string { - cargoPath := filepath.Join(dir, "Cargo.toml") - data, err := os.ReadFile(cargoPath) - if err != nil { - return "" - } - content := string(data) - - frameworks := []struct { - crate string - name string - }{ - {"actix-web", "actix"}, - {"rocket", "rocket"}, - {"axum", "axum"}, - {"warp", "warp"}, - {"tide", "tide"}, - } - - for _, fw := range frameworks { - if strings.Contains(content, fw.crate) { - return fw.name - } - } - - return "" -} - -// detectBuildSystem determines the project's build system from manifest files. -func detectBuildSystem(dir string) string { - buildSystems := []struct { - file string - system string - }{ - {"go.mod", "go modules"}, - {"package.json", "npm"}, - {"Cargo.toml", "cargo"}, - {"pom.xml", "maven"}, - {"build.gradle", "gradle"}, - {"build.gradle.kts", "gradle"}, - {"CMakeLists.txt", "cmake"}, - {"Makefile", "make"}, - {"meson.build", "meson"}, - {"BUILD", "bazel"}, - {"WORKSPACE", "bazel"}, - {"mix.exs", "mix"}, - {"pubspec.yaml", "pub"}, - {"Package.swift", "swift package manager"}, - {"Rakefile", "rake"}, - } - - for _, bs := range buildSystems { - path := filepath.Join(dir, bs.file) - if _, err := os.Stat(path); err == nil { - return bs.system - } - } - - return "" -} - -// detectProjectPackageManager determines the project's package manager. -func detectProjectPackageManager(dir string) string { - managers := []struct { - file string - manager string - }{ - {"pnpm-lock.yaml", "pnpm"}, - {"yarn.lock", "yarn"}, - {"package-lock.json", "npm"}, - {"bun.lockb", "bun"}, - {"go.sum", "go modules"}, - {"Cargo.lock", "cargo"}, - {"Pipfile.lock", "pipenv"}, - {"poetry.lock", "poetry"}, - {"Gemfile.lock", "bundler"}, - {"composer.lock", "composer"}, - {"pubspec.lock", "pub"}, - {"mix.lock", "mix"}, - } - - for _, m := range managers { - path := filepath.Join(dir, m.file) - if _, err := os.Stat(path); err == nil { - return m.manager - } - } - - // Fallback: check manifest files. - fallbacks := []struct { - file string - manager string - }{ - {"go.mod", "go modules"}, - {"package.json", "npm"}, - {"Cargo.toml", "cargo"}, - {"requirements.txt", "pip"}, - {"pyproject.toml", "pip"}, - {"Gemfile", "bundler"}, - {"composer.json", "composer"}, - } - - for _, f := range fallbacks { - path := filepath.Join(dir, f.file) - if _, err := os.Stat(path); err == nil { - return f.manager - } - } - - return "" -} - -// detectTestFramework determines the test framework used. -func detectTestFramework(dir string, lang string) string { - switch lang { - case "Go": - // Go has a built-in test framework. - // Check for testify or other test libs in go.mod. - goModPath := filepath.Join(dir, "go.mod") - if data, err := os.ReadFile(goModPath); err == nil { - content := string(data) - if strings.Contains(content, "github.com/stretchr/testify") { - return "go test + testify" - } - } - // Check if there are any _test.go files. - hasTests := false - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil || hasTests { - return filepath.SkipAll - } - if d.IsDir() && skipDirs[d.Name()] { - return filepath.SkipDir - } - if !d.IsDir() && strings.HasSuffix(d.Name(), "_test.go") { - hasTests = true - return filepath.SkipAll - } - return nil - }) - if hasTests { - return "go test" - } - return "" - - case "JavaScript", "TypeScript": - pkgPath := filepath.Join(dir, "package.json") - data, err := os.ReadFile(pkgPath) - if err != nil { - return "" - } - var pkg struct { - Dependencies map[string]interface{} `json:"dependencies"` - DevDependencies map[string]interface{} `json:"devDependencies"` - Scripts map[string]string `json:"scripts"` - } - if err := json.Unmarshal(data, &pkg); err != nil { - return "" - } - - allDeps := make(map[string]bool) - for k := range pkg.Dependencies { - allDeps[k] = true - } - for k := range pkg.DevDependencies { - allDeps[k] = true - } - - if allDeps["vitest"] { - return "vitest" - } - if allDeps["jest"] || allDeps["@jest/core"] || allDeps["ts-jest"] { - return "jest" - } - if allDeps["mocha"] { - return "mocha" - } - if allDeps["ava"] { - return "ava" - } - if allDeps["cypress"] { - return "cypress" - } - if allDeps["playwright"] || allDeps["@playwright/test"] { - return "playwright" - } - return "" - - case "Python": - // Check for pytest in requirements or installed. - files := []string{ - filepath.Join(dir, "requirements.txt"), - filepath.Join(dir, "requirements-dev.txt"), - filepath.Join(dir, "pyproject.toml"), - filepath.Join(dir, "setup.cfg"), - filepath.Join(dir, "Pipfile"), - } - for _, f := range files { - data, err := os.ReadFile(f) - if err != nil { - continue - } - content := string(data) - if strings.Contains(content, "pytest") { - return "pytest" - } - } - // Check for pytest.ini or conftest.py. - if _, err := os.Stat(filepath.Join(dir, "pytest.ini")); err == nil { - return "pytest" - } - if _, err := os.Stat(filepath.Join(dir, "conftest.py")); err == nil { - return "pytest" - } - // Check for test files with unittest patterns. - hasUnittest := false - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil || hasUnittest { - return filepath.SkipAll - } - if d.IsDir() && skipDirs[d.Name()] { - return filepath.SkipDir - } - if !d.IsDir() && (strings.HasPrefix(d.Name(), "test_") || strings.HasSuffix(d.Name(), "_test.py")) { - hasUnittest = true - return filepath.SkipAll - } - return nil - }) - if hasUnittest { - return "unittest" - } - return "" - - case "Rust": - // Rust uses built-in test framework via cargo test. - cargoPath := filepath.Join(dir, "Cargo.toml") - if _, err := os.Stat(cargoPath); err == nil { - return "cargo test" - } - return "" - - case "Java": - pomPath := filepath.Join(dir, "pom.xml") - if data, err := os.ReadFile(pomPath); err == nil { - content := string(data) - if strings.Contains(content, "junit") || strings.Contains(content, "JUnit") { - return "junit" - } - if strings.Contains(content, "testng") { - return "testng" - } - } - gradlePath := filepath.Join(dir, "build.gradle") - if data, err := os.ReadFile(gradlePath); err == nil { - content := string(data) - if strings.Contains(content, "junit") || strings.Contains(content, "JUnit") { - return "junit" - } - if strings.Contains(content, "testng") { - return "testng" - } - } - return "" - - case "Ruby": - if _, err := os.Stat(filepath.Join(dir, "spec")); err == nil { - return "rspec" - } - gemPath := filepath.Join(dir, "Gemfile") - if data, err := os.ReadFile(gemPath); err == nil { - content := string(data) - if strings.Contains(content, "rspec") { - return "rspec" - } - if strings.Contains(content, "minitest") { - return "minitest" - } - } - return "" - } - - return "" -} - -// detectLintTools detects configured linting tools from config files. -func detectLintTools(dir string) []string { - var tools []string - - lintConfigs := []struct { - file string - tool string - }{ - {".golangci.yml", "golangci-lint"}, - {".golangci.yaml", "golangci-lint"}, - {".golangci.toml", "golangci-lint"}, - {".eslintrc", "eslint"}, - {".eslintrc.js", "eslint"}, - {".eslintrc.json", "eslint"}, - {".eslintrc.yml", "eslint"}, - {"eslint.config.js", "eslint"}, - {"eslint.config.mjs", "eslint"}, - {".prettierrc", "prettier"}, - {".prettierrc.js", "prettier"}, - {".prettierrc.json", "prettier"}, - {"prettier.config.js", "prettier"}, - {".stylelintrc", "stylelint"}, - {".stylelintrc.json", "stylelint"}, - {"stylelint.config.js", "stylelint"}, - {".flake8", "flake8"}, - {"setup.cfg", "flake8"}, // might contain flake8 config - {".pylintrc", "pylint"}, - {"pyproject.toml", "ruff"}, // might contain ruff config - {".rubocop.yml", "rubocop"}, - {"clippy.toml", "clippy"}, - {".clippy.toml", "clippy"}, - {".editorconfig", "editorconfig"}, - {"biome.json", "biome"}, - {"deno.json", "deno lint"}, - {".hadolint.yaml", "hadolint"}, - {".shellcheckrc", "shellcheck"}, - } - - seen := make(map[string]bool) - for _, lc := range lintConfigs { - path := filepath.Join(dir, lc.file) - if _, err := os.Stat(path); err == nil { - // Special case: setup.cfg / pyproject.toml may or may not contain lint config. - if lc.file == "setup.cfg" { - if data, err := os.ReadFile(path); err == nil { - if !strings.Contains(string(data), "[flake8]") { - continue - } - } - } - if lc.file == "pyproject.toml" { - if data, err := os.ReadFile(path); err == nil { - if !strings.Contains(string(data), "[tool.ruff]") && !strings.Contains(string(data), "ruff") { - continue - } - } - } - if !seen[lc.tool] { - seen[lc.tool] = true - tools = append(tools, lc.tool) - } - } - } - - // Check package.json for lint-related devDependencies. - pkgPath := filepath.Join(dir, "package.json") - if data, err := os.ReadFile(pkgPath); err == nil { - var pkg struct { - DevDependencies map[string]interface{} `json:"devDependencies"` - } - if err := json.Unmarshal(data, &pkg); err == nil { - jsDeps := []struct { - pkg string - tool string - }{ - {"eslint", "eslint"}, - {"prettier", "prettier"}, - {"stylelint", "stylelint"}, - {"biome", "biome"}, - {"@biomejs/biome", "biome"}, - } - for _, jd := range jsDeps { - if _, ok := pkg.DevDependencies[jd.pkg]; ok && !seen[jd.tool] { - seen[jd.tool] = true - tools = append(tools, jd.tool) - } - } - } - } - - return tools -} - -// detectCISystem identifies the CI/CD system in use and returns its name. -func detectCISystem(dir string) string { - ciSystems := []struct { - path string - isDir bool - system string - }{ - {filepath.Join(".github", "workflows"), true, "github-actions"}, - {".gitlab-ci.yml", false, "gitlab-ci"}, - {".circleci", true, "circleci"}, - {"Jenkinsfile", false, "jenkins"}, - {".travis.yml", false, "travis-ci"}, - {"azure-pipelines.yml", false, "azure-pipelines"}, - {"bitbucket-pipelines.yml", false, "bitbucket-pipelines"}, - {".drone.yml", false, "drone"}, - {".buildkite", true, "buildkite"}, - {"cloudbuild.yaml", false, "google-cloud-build"}, - {"cloudbuild.yml", false, "google-cloud-build"}, - {".tekton", true, "tekton"}, - } - - for _, ci := range ciSystems { - full := filepath.Join(dir, ci.path) - info, err := os.Stat(full) - if err != nil { - continue - } - if ci.isDir && info.IsDir() { - return ci.system - } - if !ci.isDir && !info.IsDir() { - return ci.system - } - } - - return "" -} - -// detectDocker checks for Dockerfile or docker-compose files. -func detectDocker(dir string) bool { - dockerFiles := []string{ - "Dockerfile", - "dockerfile", - "docker-compose.yml", - "docker-compose.yaml", - "compose.yml", - "compose.yaml", - ".dockerignore", - } - - for _, f := range dockerFiles { - path := filepath.Join(dir, f) - if _, err := os.Stat(path); err == nil { - return true - } - } - - return false -} - -// detectMonorepo checks for indicators of a monorepo structure. -func detectMonorepo(dir string) bool { - // Check for multiple go.mod files. - goModCount := 0 - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return nil - } - if d.IsDir() && skipDirs[d.Name()] { - return filepath.SkipDir - } - if !d.IsDir() && d.Name() == "go.mod" { - goModCount++ - if goModCount > 1 { - return filepath.SkipAll - } - } - return nil - }) - if goModCount > 1 { - return true - } - - // Check for go.work file. - if _, err := os.Stat(filepath.Join(dir, "go.work")); err == nil { - return true - } - - // Check for packages/ or apps/ directory (JS monorepos). - monoDirs := []string{"packages", "apps", "modules", "services", "libs"} - for _, md := range monoDirs { - path := filepath.Join(dir, md) - if info, err := os.Stat(path); err == nil && info.IsDir() { - return true - } - } - - // Check for lerna.json, pnpm-workspace.yaml, or turbo.json. - monoFiles := []string{"lerna.json", "pnpm-workspace.yaml", "turbo.json", "nx.json"} - for _, mf := range monoFiles { - path := filepath.Join(dir, mf) - if _, err := os.Stat(path); err == nil { - return true - } - } - - // Check package.json for workspaces field. - pkgPath := filepath.Join(dir, "package.json") - if data, err := os.ReadFile(pkgPath); err == nil { - var pkg map[string]interface{} - if err := json.Unmarshal(data, &pkg); err == nil { - if _, ok := pkg["workspaces"]; ok { - return true - } - } - } - - return false -} - -// classifyProjectSize returns a size classification based on file count. -func classifyProjectSize(fileCount int) string { - switch { - case fileCount < 10: - return "tiny" - case fileCount < 100: - return "small" - case fileCount <= 1000: - return "medium" - default: - return "large" - } -} - -// detectConventions analyzes the project to identify coding conventions. -func detectConventions(dir string, lang string) []Convention { - var conventions []Convention - - // Detect indentation from .editorconfig. - if conv := detectIndentationConvention(dir); conv != nil { - conventions = append(conventions, *conv) - } - - // Detect naming convention by sampling source files. - if conv := detectNamingConvention(dir, lang); conv != nil { - conventions = append(conventions, *conv) - } - - // Detect error handling style (Go-specific). - if lang == "Go" { - if conv := detectGoErrorHandling(dir); conv != nil { - conventions = append(conventions, *conv) - } - } - - // Detect import organization. - if conv := detectImportOrganization(dir, lang); conv != nil { - conventions = append(conventions, *conv) - } - - // Detect test naming convention. - if conv := detectTestNaming(dir, lang); conv != nil { - conventions = append(conventions, *conv) - } - - // Detect commit message style. - if conv := detectCommitStyle(dir); conv != nil { - conventions = append(conventions, *conv) - } - - return conventions -} - -// detectIndentationConvention reads .editorconfig or samples files. -func detectIndentationConvention(dir string) *Convention { - // Check .editorconfig first. - editorConfigPath := filepath.Join(dir, ".editorconfig") - if data, err := os.ReadFile(editorConfigPath); err == nil { - content := strings.ToLower(string(data)) - if strings.Contains(content, "indent_style = tab") { - return &Convention{ - Name: "indentation", - Description: "Tabs for indentation", - Confidence: 1.0, - } - } - if strings.Contains(content, "indent_style = space") { - // Try to find indent_size. - size := "unknown" - lines := strings.Split(content, "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "indent_size") { - parts := strings.SplitN(line, "=", 2) - if len(parts) == 2 { - size = strings.TrimSpace(parts[1]) - } - } - } - desc := "Spaces for indentation" - if size != "unknown" { - desc = fmt.Sprintf("%s-space indentation", size) - } - return &Convention{ - Name: "indentation", - Description: desc, - Confidence: 1.0, - } - } - } - - // Sample source files to detect indentation. - tabCount := 0 - spaceCount := 0 - sampled := 0 - maxSamples := 20 - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil || sampled >= maxSamples { - return filepath.SkipAll - } - if d.IsDir() { - if skipDirs[d.Name()] { - return filepath.SkipDir - } - return nil - } - - ext := filepath.Ext(path) - if _, ok := extToLang[ext]; !ok { - return nil - } - - f, err := os.Open(path) - if err != nil { - return nil - } - defer func() { _ = f.Close() }() - - scanner := bufio.NewScanner(f) - lineCount := 0 - for scanner.Scan() && lineCount < 50 { - line := scanner.Text() - if len(line) > 0 { - if line[0] == '\t' { - tabCount++ - } else if line[0] == ' ' && len(line) > 1 && line[1] == ' ' { - spaceCount++ - } - } - lineCount++ - } - sampled++ - return nil - }) - - total := tabCount + spaceCount - if total == 0 { - return nil - } - - if tabCount > spaceCount { - confidence := float64(tabCount) / float64(total) - return &Convention{ - Name: "indentation", - Description: "Tabs for indentation", - Confidence: confidence, - } - } - confidence := float64(spaceCount) / float64(total) - return &Convention{ - Name: "indentation", - Description: "Spaces for indentation", - Confidence: confidence, - } -} - -// detectNamingConvention samples identifiers to determine naming style. -func detectNamingConvention(dir string, lang string) *Convention { - // For Go, the convention is well-known: exported = PascalCase, local = camelCase. - if lang == "Go" { - return &Convention{ - Name: "naming", - Description: "camelCase/PascalCase (Go standard)", - Confidence: 1.0, - } - } - - // For Python, sample for snake_case vs camelCase. - if lang == "Python" { - snakeCount := 0 - camelCount := 0 - sampled := 0 - - snakeRe := regexp.MustCompile(`\bdef ([a-z][a-z0-9]*_[a-z0-9_]+)\b`) - camelRe := regexp.MustCompile(`\bdef ([a-z][a-zA-Z0-9]+[A-Z][a-zA-Z0-9]*)\b`) - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil || sampled >= 10 { - return filepath.SkipAll - } - if d.IsDir() && skipDirs[d.Name()] { - return filepath.SkipDir - } - if d.IsDir() || filepath.Ext(path) != ".py" { - return nil - } - - data, err := os.ReadFile(path) - if err != nil { - return nil - } - content := string(data) - snakeCount += len(snakeRe.FindAllString(content, -1)) - camelCount += len(camelRe.FindAllString(content, -1)) - sampled++ - return nil - }) - - total := snakeCount + camelCount - if total == 0 { - return nil - } - if snakeCount > camelCount { - return &Convention{ - Name: "naming", - Description: "snake_case (Python standard)", - Confidence: float64(snakeCount) / float64(total), - } - } - return &Convention{ - Name: "naming", - Description: "camelCase", - Confidence: float64(camelCount) / float64(total), - } - } - - return nil -} - -// detectGoErrorHandling checks error handling patterns in Go source files. -func detectGoErrorHandling(dir string) *Convention { - wrapCount := 0 // fmt.Errorf("...: %w", err) - bareCount := 0 // return err (without wrapping) - sampled := 0 - - wrapRe := regexp.MustCompile(`fmt\.Errorf\([^)]*%w`) - bareRe := regexp.MustCompile(`return\s+err\b`) - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil || sampled >= 20 { - return filepath.SkipAll - } - if d.IsDir() && skipDirs[d.Name()] { - return filepath.SkipDir - } - if d.IsDir() || !strings.HasSuffix(d.Name(), ".go") || strings.HasSuffix(d.Name(), "_test.go") { - return nil - } - - data, err := os.ReadFile(path) - if err != nil { - return nil - } - content := string(data) - wrapCount += len(wrapRe.FindAllString(content, -1)) - bareCount += len(bareRe.FindAllString(content, -1)) - sampled++ - return nil - }) - - total := wrapCount + bareCount - if total == 0 { - return nil - } - - if wrapCount > bareCount { - return &Convention{ - Name: "error-handling", - Description: "Error wrapping with %w", - Confidence: float64(wrapCount) / float64(total), - } - } - return &Convention{ - Name: "error-handling", - Description: "Bare error returns", - Confidence: float64(bareCount) / float64(total), - } -} - -// detectImportOrganization checks if imports are grouped (stdlib vs third-party). -func detectImportOrganization(dir string, lang string) *Convention { - if lang != "Go" { - return nil - } - - groupedCount := 0 - ungroupedCount := 0 - sampled := 0 - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil || sampled >= 15 { - return filepath.SkipAll - } - if d.IsDir() && skipDirs[d.Name()] { - return filepath.SkipDir - } - if d.IsDir() || !strings.HasSuffix(d.Name(), ".go") || strings.HasSuffix(d.Name(), "_test.go") { - return nil - } - - data, err := os.ReadFile(path) - if err != nil { - return nil - } - content := string(data) - - // Find import blocks. - importStart := strings.Index(content, "import (") - if importStart == -1 { - return nil - } - importEnd := strings.Index(content[importStart:], ")") - if importEnd == -1 { - return nil - } - importBlock := content[importStart : importStart+importEnd] - - // Check for blank lines within the import block (indicating grouping). - if strings.Contains(importBlock, "\n\n") { - groupedCount++ - } else { - // Only count as ungrouped if there are multiple imports. - lines := strings.Split(importBlock, "\n") - importLines := 0 - for _, l := range lines { - l = strings.TrimSpace(l) - if l != "" && l != "import (" && l != ")" && !strings.HasPrefix(l, "//") { - importLines++ - } - } - if importLines > 1 { - ungroupedCount++ - } - } - sampled++ - return nil - }) - - total := groupedCount + ungroupedCount - if total == 0 { - return nil - } - - if groupedCount > ungroupedCount { - return &Convention{ - Name: "imports", - Description: "Grouped imports (stdlib separated from third-party)", - Confidence: float64(groupedCount) / float64(total), - } - } - return &Convention{ - Name: "imports", - Description: "Ungrouped imports", - Confidence: float64(ungroupedCount) / float64(total), - } -} - -// detectTestNaming checks test naming conventions. -func detectTestNaming(dir string, lang string) *Convention { - if lang != "Go" { - return nil - } - - // Check for table-driven tests vs simple tests. - tableDrivenCount := 0 - simpleCount := 0 - sampled := 0 - - tableDrivenRe := regexp.MustCompile(`(tests|cases|testCases|tt)\s*:?=\s*\[\]struct`) - simpleFuncRe := regexp.MustCompile(`func Test[A-Z]\w+\(t \*testing\.T\)`) - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil || sampled >= 15 { - return filepath.SkipAll - } - if d.IsDir() && skipDirs[d.Name()] { - return filepath.SkipDir - } - if d.IsDir() || !strings.HasSuffix(d.Name(), "_test.go") { - return nil - } - - data, err := os.ReadFile(path) - if err != nil { - return nil - } - content := string(data) - tableDrivenCount += len(tableDrivenRe.FindAllString(content, -1)) - simpleCount += len(simpleFuncRe.FindAllString(content, -1)) - sampled++ - return nil - }) - - if tableDrivenCount > 0 && simpleCount > 0 { - total := tableDrivenCount + simpleCount - if tableDrivenCount > simpleCount/2 { - return &Convention{ - Name: "test-style", - Description: "Table-driven tests", - Confidence: float64(tableDrivenCount) / float64(total), - } - } - } - - return nil -} - -// detectCommitStyle checks git log for conventional commits or other patterns. -func detectCommitStyle(dir string) *Convention { - cmd := exec.CommandContext(context.Background(), "git", "log", "--oneline", "-20", "--format=%s") - cmd.Dir = dir - out, err := cmd.Output() - if err != nil { - return nil - } - - lines := strings.Split(strings.TrimSpace(string(out)), "\n") - if len(lines) == 0 { - return nil - } - - // Check for conventional commits (feat:, fix:, chore:, etc.). - conventionalRe := regexp.MustCompile(`^(feat|fix|chore|docs|style|refactor|perf|test|build|ci|revert)(\(.+\))?:`) - conventionalCount := 0 - - for _, line := range lines { - if conventionalRe.MatchString(line) { - conventionalCount++ - } - } - - if conventionalCount > 0 { - confidence := float64(conventionalCount) / float64(len(lines)) - if confidence >= 0.3 { - return &Convention{ - Name: "commit-style", - Description: "Conventional commits (feat:, fix:, etc.)", - Confidence: confidence, - } - } - } - - return nil -} - // generateRecommendations produces hawk configuration suggestions based on the // detected project fingerprint. func generateRecommendations(fp *ProjectFingerprint) []string { diff --git a/internal/feature/fingerprint/project_conventions.go b/internal/feature/fingerprint/project_conventions.go new file mode 100644 index 00000000..55949c4b --- /dev/null +++ b/internal/feature/fingerprint/project_conventions.go @@ -0,0 +1,439 @@ +package fingerprint + +import ( + "bufio" + "context" + "fmt" + "io/fs" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" +) + +// This file holds the coding-convention detectors used by Scan (indentation, +// naming, error handling, import organization, test style, commit style). The +// language/build detectors live in project_detect.go. + +// detectConventions analyzes the project to identify coding conventions. +func detectConventions(dir string, lang string) []Convention { + var conventions []Convention + + // Detect indentation from .editorconfig. + if conv := detectIndentationConvention(dir); conv != nil { + conventions = append(conventions, *conv) + } + + // Detect naming convention by sampling source files. + if conv := detectNamingConvention(dir, lang); conv != nil { + conventions = append(conventions, *conv) + } + + // Detect error handling style (Go-specific). + if lang == "Go" { + if conv := detectGoErrorHandling(dir); conv != nil { + conventions = append(conventions, *conv) + } + } + + // Detect import organization. + if conv := detectImportOrganization(dir, lang); conv != nil { + conventions = append(conventions, *conv) + } + + // Detect test naming convention. + if conv := detectTestNaming(dir, lang); conv != nil { + conventions = append(conventions, *conv) + } + + // Detect commit message style. + if conv := detectCommitStyle(dir); conv != nil { + conventions = append(conventions, *conv) + } + + return conventions +} + +// detectIndentationConvention reads .editorconfig or samples files. +func detectIndentationConvention(dir string) *Convention { + // Check .editorconfig first. + editorConfigPath := filepath.Join(dir, ".editorconfig") + if data, err := os.ReadFile(editorConfigPath); err == nil { + content := strings.ToLower(string(data)) + if strings.Contains(content, "indent_style = tab") { + return &Convention{ + Name: "indentation", + Description: "Tabs for indentation", + Confidence: 1.0, + } + } + if strings.Contains(content, "indent_style = space") { + // Try to find indent_size. + size := "unknown" + lines := strings.Split(content, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "indent_size") { + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + size = strings.TrimSpace(parts[1]) + } + } + } + desc := "Spaces for indentation" + if size != "unknown" { + desc = fmt.Sprintf("%s-space indentation", size) + } + return &Convention{ + Name: "indentation", + Description: desc, + Confidence: 1.0, + } + } + } + + // Sample source files to detect indentation. + tabCount := 0 + spaceCount := 0 + sampled := 0 + maxSamples := 20 + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil || sampled >= maxSamples { + return filepath.SkipAll + } + if d.IsDir() { + if skipDirs[d.Name()] { + return filepath.SkipDir + } + return nil + } + + ext := filepath.Ext(path) + if _, ok := extToLang[ext]; !ok { + return nil + } + + f, err := os.Open(path) + if err != nil { + return nil + } + defer func() { _ = f.Close() }() + + scanner := bufio.NewScanner(f) + lineCount := 0 + for scanner.Scan() && lineCount < 50 { + line := scanner.Text() + if len(line) > 0 { + if line[0] == '\t' { + tabCount++ + } else if line[0] == ' ' && len(line) > 1 && line[1] == ' ' { + spaceCount++ + } + } + lineCount++ + } + sampled++ + return nil + }) + + total := tabCount + spaceCount + if total == 0 { + return nil + } + + if tabCount > spaceCount { + confidence := float64(tabCount) / float64(total) + return &Convention{ + Name: "indentation", + Description: "Tabs for indentation", + Confidence: confidence, + } + } + confidence := float64(spaceCount) / float64(total) + return &Convention{ + Name: "indentation", + Description: "Spaces for indentation", + Confidence: confidence, + } +} + +// detectNamingConvention samples identifiers to determine naming style. +func detectNamingConvention(dir string, lang string) *Convention { + // For Go, the convention is well-known: exported = PascalCase, local = camelCase. + if lang == "Go" { + return &Convention{ + Name: "naming", + Description: "camelCase/PascalCase (Go standard)", + Confidence: 1.0, + } + } + + // For Python, sample for snake_case vs camelCase. + if lang == "Python" { + snakeCount := 0 + camelCount := 0 + sampled := 0 + + snakeRe := regexp.MustCompile(`\bdef ([a-z][a-z0-9]*_[a-z0-9_]+)\b`) + camelRe := regexp.MustCompile(`\bdef ([a-z][a-zA-Z0-9]+[A-Z][a-zA-Z0-9]*)\b`) + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil || sampled >= 10 { + return filepath.SkipAll + } + if d.IsDir() && skipDirs[d.Name()] { + return filepath.SkipDir + } + if d.IsDir() || filepath.Ext(path) != ".py" { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil + } + content := string(data) + snakeCount += len(snakeRe.FindAllString(content, -1)) + camelCount += len(camelRe.FindAllString(content, -1)) + sampled++ + return nil + }) + + total := snakeCount + camelCount + if total == 0 { + return nil + } + if snakeCount > camelCount { + return &Convention{ + Name: "naming", + Description: "snake_case (Python standard)", + Confidence: float64(snakeCount) / float64(total), + } + } + return &Convention{ + Name: "naming", + Description: "camelCase", + Confidence: float64(camelCount) / float64(total), + } + } + + return nil +} + +// detectGoErrorHandling checks error handling patterns in Go source files. +func detectGoErrorHandling(dir string) *Convention { + wrapCount := 0 // fmt.Errorf("...: %w", err) + bareCount := 0 // return err (without wrapping) + sampled := 0 + + wrapRe := regexp.MustCompile(`fmt\.Errorf\([^)]*%w`) + bareRe := regexp.MustCompile(`return\s+err\b`) + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil || sampled >= 20 { + return filepath.SkipAll + } + if d.IsDir() && skipDirs[d.Name()] { + return filepath.SkipDir + } + if d.IsDir() || !strings.HasSuffix(d.Name(), ".go") || strings.HasSuffix(d.Name(), "_test.go") { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil + } + content := string(data) + wrapCount += len(wrapRe.FindAllString(content, -1)) + bareCount += len(bareRe.FindAllString(content, -1)) + sampled++ + return nil + }) + + total := wrapCount + bareCount + if total == 0 { + return nil + } + + if wrapCount > bareCount { + return &Convention{ + Name: "error-handling", + Description: "Error wrapping with %w", + Confidence: float64(wrapCount) / float64(total), + } + } + return &Convention{ + Name: "error-handling", + Description: "Bare error returns", + Confidence: float64(bareCount) / float64(total), + } +} + +// detectImportOrganization checks if imports are grouped (stdlib vs third-party). +func detectImportOrganization(dir string, lang string) *Convention { + if lang != "Go" { + return nil + } + + groupedCount := 0 + ungroupedCount := 0 + sampled := 0 + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil || sampled >= 15 { + return filepath.SkipAll + } + if d.IsDir() && skipDirs[d.Name()] { + return filepath.SkipDir + } + if d.IsDir() || !strings.HasSuffix(d.Name(), ".go") || strings.HasSuffix(d.Name(), "_test.go") { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil + } + content := string(data) + + // Find import blocks. + importStart := strings.Index(content, "import (") + if importStart == -1 { + return nil + } + importEnd := strings.Index(content[importStart:], ")") + if importEnd == -1 { + return nil + } + importBlock := content[importStart : importStart+importEnd] + + // Check for blank lines within the import block (indicating grouping). + if strings.Contains(importBlock, "\n\n") { + groupedCount++ + } else { + // Only count as ungrouped if there are multiple imports. + lines := strings.Split(importBlock, "\n") + importLines := 0 + for _, l := range lines { + l = strings.TrimSpace(l) + if l != "" && l != "import (" && l != ")" && !strings.HasPrefix(l, "//") { + importLines++ + } + } + if importLines > 1 { + ungroupedCount++ + } + } + sampled++ + return nil + }) + + total := groupedCount + ungroupedCount + if total == 0 { + return nil + } + + if groupedCount > ungroupedCount { + return &Convention{ + Name: "imports", + Description: "Grouped imports (stdlib separated from third-party)", + Confidence: float64(groupedCount) / float64(total), + } + } + return &Convention{ + Name: "imports", + Description: "Ungrouped imports", + Confidence: float64(ungroupedCount) / float64(total), + } +} + +// detectTestNaming checks test naming conventions. +func detectTestNaming(dir string, lang string) *Convention { + if lang != "Go" { + return nil + } + + // Check for table-driven tests vs simple tests. + tableDrivenCount := 0 + simpleCount := 0 + sampled := 0 + + tableDrivenRe := regexp.MustCompile(`(tests|cases|testCases|tt)\s*:?=\s*\[\]struct`) + simpleFuncRe := regexp.MustCompile(`func Test[A-Z]\w+\(t \*testing\.T\)`) + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil || sampled >= 15 { + return filepath.SkipAll + } + if d.IsDir() && skipDirs[d.Name()] { + return filepath.SkipDir + } + if d.IsDir() || !strings.HasSuffix(d.Name(), "_test.go") { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil + } + content := string(data) + tableDrivenCount += len(tableDrivenRe.FindAllString(content, -1)) + simpleCount += len(simpleFuncRe.FindAllString(content, -1)) + sampled++ + return nil + }) + + if tableDrivenCount > 0 && simpleCount > 0 { + total := tableDrivenCount + simpleCount + if tableDrivenCount > simpleCount/2 { + return &Convention{ + Name: "test-style", + Description: "Table-driven tests", + Confidence: float64(tableDrivenCount) / float64(total), + } + } + } + + return nil +} + +// detectCommitStyle checks git log for conventional commits or other patterns. +func detectCommitStyle(dir string) *Convention { + cmd := exec.CommandContext(context.Background(), "git", "log", "--oneline", "-20", "--format=%s") + cmd.Dir = dir + out, err := cmd.Output() + if err != nil { + return nil + } + + lines := strings.Split(strings.TrimSpace(string(out)), "\n") + if len(lines) == 0 { + return nil + } + + // Check for conventional commits (feat:, fix:, chore:, etc.). + conventionalRe := regexp.MustCompile(`^(feat|fix|chore|docs|style|refactor|perf|test|build|ci|revert)(\(.+\))?:`) + conventionalCount := 0 + + for _, line := range lines { + if conventionalRe.MatchString(line) { + conventionalCount++ + } + } + + if conventionalCount > 0 { + confidence := float64(conventionalCount) / float64(len(lines)) + if confidence >= 0.3 { + return &Convention{ + Name: "commit-style", + Description: "Conventional commits (feat:, fix:, etc.)", + Confidence: confidence, + } + } + } + + return nil +} diff --git a/internal/feature/fingerprint/project_detect.go b/internal/feature/fingerprint/project_detect.go new file mode 100644 index 00000000..56cc1eb8 --- /dev/null +++ b/internal/feature/fingerprint/project_detect.go @@ -0,0 +1,718 @@ +package fingerprint + +import ( + "encoding/json" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" +) + +// This file holds the language/framework/build/CI/etc. detectors used by Scan. +// Convention detection lives in project_conventions.go; the Scan orchestration, +// types, recommendations, and summary formatting live in project.go. + +// detectLanguages walks the project directory, maps extensions to languages, +// calculates percentages, and returns results sorted by file count descending. +func detectLanguages(dir string) []ProjectLangInfo { + counts := make(map[string]int) + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + if skipDirs[d.Name()] { + return filepath.SkipDir + } + return nil + } + if !d.Type().IsRegular() { + return nil + } + + ext := filepath.Ext(path) + if lang, ok := extToLang[ext]; ok { + counts[lang]++ + } + return nil + }) + + if len(counts) == 0 { + return nil + } + + total := 0 + for _, c := range counts { + total += c + } + + langs := make([]ProjectLangInfo, 0, len(counts)) + for name, count := range counts { + pct := float64(count) / float64(total) * 100 + langs = append(langs, ProjectLangInfo{ + Name: name, + FileCount: count, + Percentage: pct, + }) + } + + sort.Slice(langs, func(i, j int) bool { + return langs[i].FileCount > langs[j].FileCount + }) + + return langs +} + +// detectFramework reads key config files to determine the web/app framework. +func detectFramework(dir string, primaryLang string) string { + switch primaryLang { + case "Go": + return detectGoFramework(dir) + case "Python": + return detectPythonFramework(dir) + case "JavaScript", "TypeScript": + return detectJSFramework(dir) + case "Rust": + return detectRustFramework(dir) + } + return "" +} + +// detectGoFramework reads go.mod for known Go web frameworks. +func detectGoFramework(dir string) string { + goModPath := filepath.Join(dir, "go.mod") + data, err := os.ReadFile(goModPath) + if err != nil { + return "" + } + content := string(data) + + frameworks := []struct { + module string + name string + }{ + {"github.com/go-chi/chi", "chi"}, + {"github.com/gin-gonic/gin", "gin"}, + {"github.com/labstack/echo", "echo"}, + {"github.com/gofiber/fiber", "fiber"}, + {"github.com/gorilla/mux", "gorilla"}, + {"github.com/julienschmidt/httprouter", "httprouter"}, + {"github.com/valyala/fasthttp", "fasthttp"}, + } + + for _, fw := range frameworks { + if strings.Contains(content, fw.module) { + return fw.name + } + } + + // Check for net/http usage (fallback — it's in stdlib so not in go.mod). + return "" +} + +// detectPythonFramework reads requirements.txt, Pipfile, or pyproject.toml. +func detectPythonFramework(dir string) string { + files := []string{ + filepath.Join(dir, "requirements.txt"), + filepath.Join(dir, "Pipfile"), + filepath.Join(dir, "pyproject.toml"), + filepath.Join(dir, "setup.py"), + } + + frameworks := []struct { + keyword string + name string + }{ + {"django", "django"}, + {"Django", "django"}, + {"flask", "flask"}, + {"Flask", "flask"}, + {"fastapi", "fastapi"}, + {"FastAPI", "fastapi"}, + {"tornado", "tornado"}, + {"starlette", "starlette"}, + {"sanic", "sanic"}, + } + + for _, f := range files { + data, err := os.ReadFile(f) + if err != nil { + continue + } + content := string(data) + for _, fw := range frameworks { + if strings.Contains(content, fw.keyword) { + return fw.name + } + } + } + + return "" +} + +// detectJSFramework reads package.json for known JS/TS frameworks. +func detectJSFramework(dir string) string { + pkgPath := filepath.Join(dir, "package.json") + data, err := os.ReadFile(pkgPath) + if err != nil { + return "" + } + + var pkg struct { + Dependencies map[string]interface{} `json:"dependencies"` + DevDependencies map[string]interface{} `json:"devDependencies"` + } + if err := json.Unmarshal(data, &pkg); err != nil { + return "" + } + + // Merge deps for lookup. + allDeps := make(map[string]bool) + for k := range pkg.Dependencies { + allDeps[k] = true + } + for k := range pkg.DevDependencies { + allDeps[k] = true + } + + frameworks := []struct { + pkg string + name string + }{ + {"next", "next.js"}, + {"nuxt", "nuxt"}, + {"@angular/core", "angular"}, + {"vue", "vue"}, + {"svelte", "svelte"}, + {"express", "express"}, + {"fastify", "fastify"}, + {"koa", "koa"}, + {"hapi", "hapi"}, + {"react", "react"}, + {"gatsby", "gatsby"}, + {"remix", "remix"}, + } + + for _, fw := range frameworks { + if allDeps[fw.pkg] { + return fw.name + } + } + + return "" +} + +// detectRustFramework reads Cargo.toml for known Rust web frameworks. +func detectRustFramework(dir string) string { + cargoPath := filepath.Join(dir, "Cargo.toml") + data, err := os.ReadFile(cargoPath) + if err != nil { + return "" + } + content := string(data) + + frameworks := []struct { + crate string + name string + }{ + {"actix-web", "actix"}, + {"rocket", "rocket"}, + {"axum", "axum"}, + {"warp", "warp"}, + {"tide", "tide"}, + } + + for _, fw := range frameworks { + if strings.Contains(content, fw.crate) { + return fw.name + } + } + + return "" +} + +// detectBuildSystem determines the project's build system from manifest files. +func detectBuildSystem(dir string) string { + buildSystems := []struct { + file string + system string + }{ + {"go.mod", "go modules"}, + {"package.json", "npm"}, + {"Cargo.toml", "cargo"}, + {"pom.xml", "maven"}, + {"build.gradle", "gradle"}, + {"build.gradle.kts", "gradle"}, + {"CMakeLists.txt", "cmake"}, + {"Makefile", "make"}, + {"meson.build", "meson"}, + {"BUILD", "bazel"}, + {"WORKSPACE", "bazel"}, + {"mix.exs", "mix"}, + {"pubspec.yaml", "pub"}, + {"Package.swift", "swift package manager"}, + {"Rakefile", "rake"}, + } + + for _, bs := range buildSystems { + path := filepath.Join(dir, bs.file) + if _, err := os.Stat(path); err == nil { + return bs.system + } + } + + return "" +} + +// detectProjectPackageManager determines the project's package manager. +func detectProjectPackageManager(dir string) string { + managers := []struct { + file string + manager string + }{ + {"pnpm-lock.yaml", "pnpm"}, + {"yarn.lock", "yarn"}, + {"package-lock.json", "npm"}, + {"bun.lockb", "bun"}, + {"go.sum", "go modules"}, + {"Cargo.lock", "cargo"}, + {"Pipfile.lock", "pipenv"}, + {"poetry.lock", "poetry"}, + {"Gemfile.lock", "bundler"}, + {"composer.lock", "composer"}, + {"pubspec.lock", "pub"}, + {"mix.lock", "mix"}, + } + + for _, m := range managers { + path := filepath.Join(dir, m.file) + if _, err := os.Stat(path); err == nil { + return m.manager + } + } + + // Fallback: check manifest files. + fallbacks := []struct { + file string + manager string + }{ + {"go.mod", "go modules"}, + {"package.json", "npm"}, + {"Cargo.toml", "cargo"}, + {"requirements.txt", "pip"}, + {"pyproject.toml", "pip"}, + {"Gemfile", "bundler"}, + {"composer.json", "composer"}, + } + + for _, f := range fallbacks { + path := filepath.Join(dir, f.file) + if _, err := os.Stat(path); err == nil { + return f.manager + } + } + + return "" +} + +// detectTestFramework determines the test framework used. +func detectTestFramework(dir string, lang string) string { + switch lang { + case "Go": + // Go has a built-in test framework. + // Check for testify or other test libs in go.mod. + goModPath := filepath.Join(dir, "go.mod") + if data, err := os.ReadFile(goModPath); err == nil { + content := string(data) + if strings.Contains(content, "github.com/stretchr/testify") { + return "go test + testify" + } + } + // Check if there are any _test.go files. + hasTests := false + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil || hasTests { + return filepath.SkipAll + } + if d.IsDir() && skipDirs[d.Name()] { + return filepath.SkipDir + } + if !d.IsDir() && strings.HasSuffix(d.Name(), "_test.go") { + hasTests = true + return filepath.SkipAll + } + return nil + }) + if hasTests { + return "go test" + } + return "" + + case "JavaScript", "TypeScript": + pkgPath := filepath.Join(dir, "package.json") + data, err := os.ReadFile(pkgPath) + if err != nil { + return "" + } + var pkg struct { + Dependencies map[string]interface{} `json:"dependencies"` + DevDependencies map[string]interface{} `json:"devDependencies"` + Scripts map[string]string `json:"scripts"` + } + if err := json.Unmarshal(data, &pkg); err != nil { + return "" + } + + allDeps := make(map[string]bool) + for k := range pkg.Dependencies { + allDeps[k] = true + } + for k := range pkg.DevDependencies { + allDeps[k] = true + } + + if allDeps["vitest"] { + return "vitest" + } + if allDeps["jest"] || allDeps["@jest/core"] || allDeps["ts-jest"] { + return "jest" + } + if allDeps["mocha"] { + return "mocha" + } + if allDeps["ava"] { + return "ava" + } + if allDeps["cypress"] { + return "cypress" + } + if allDeps["playwright"] || allDeps["@playwright/test"] { + return "playwright" + } + return "" + + case "Python": + // Check for pytest in requirements or installed. + files := []string{ + filepath.Join(dir, "requirements.txt"), + filepath.Join(dir, "requirements-dev.txt"), + filepath.Join(dir, "pyproject.toml"), + filepath.Join(dir, "setup.cfg"), + filepath.Join(dir, "Pipfile"), + } + for _, f := range files { + data, err := os.ReadFile(f) + if err != nil { + continue + } + content := string(data) + if strings.Contains(content, "pytest") { + return "pytest" + } + } + // Check for pytest.ini or conftest.py. + if _, err := os.Stat(filepath.Join(dir, "pytest.ini")); err == nil { + return "pytest" + } + if _, err := os.Stat(filepath.Join(dir, "conftest.py")); err == nil { + return "pytest" + } + // Check for test files with unittest patterns. + hasUnittest := false + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil || hasUnittest { + return filepath.SkipAll + } + if d.IsDir() && skipDirs[d.Name()] { + return filepath.SkipDir + } + if !d.IsDir() && (strings.HasPrefix(d.Name(), "test_") || strings.HasSuffix(d.Name(), "_test.py")) { + hasUnittest = true + return filepath.SkipAll + } + return nil + }) + if hasUnittest { + return "unittest" + } + return "" + + case "Rust": + // Rust uses built-in test framework via cargo test. + cargoPath := filepath.Join(dir, "Cargo.toml") + if _, err := os.Stat(cargoPath); err == nil { + return "cargo test" + } + return "" + + case "Java": + pomPath := filepath.Join(dir, "pom.xml") + if data, err := os.ReadFile(pomPath); err == nil { + content := string(data) + if strings.Contains(content, "junit") || strings.Contains(content, "JUnit") { + return "junit" + } + if strings.Contains(content, "testng") { + return "testng" + } + } + gradlePath := filepath.Join(dir, "build.gradle") + if data, err := os.ReadFile(gradlePath); err == nil { + content := string(data) + if strings.Contains(content, "junit") || strings.Contains(content, "JUnit") { + return "junit" + } + if strings.Contains(content, "testng") { + return "testng" + } + } + return "" + + case "Ruby": + if _, err := os.Stat(filepath.Join(dir, "spec")); err == nil { + return "rspec" + } + gemPath := filepath.Join(dir, "Gemfile") + if data, err := os.ReadFile(gemPath); err == nil { + content := string(data) + if strings.Contains(content, "rspec") { + return "rspec" + } + if strings.Contains(content, "minitest") { + return "minitest" + } + } + return "" + } + + return "" +} + +// detectLintTools detects configured linting tools from config files. +func detectLintTools(dir string) []string { + var tools []string + + lintConfigs := []struct { + file string + tool string + }{ + {".golangci.yml", "golangci-lint"}, + {".golangci.yaml", "golangci-lint"}, + {".golangci.toml", "golangci-lint"}, + {".eslintrc", "eslint"}, + {".eslintrc.js", "eslint"}, + {".eslintrc.json", "eslint"}, + {".eslintrc.yml", "eslint"}, + {"eslint.config.js", "eslint"}, + {"eslint.config.mjs", "eslint"}, + {".prettierrc", "prettier"}, + {".prettierrc.js", "prettier"}, + {".prettierrc.json", "prettier"}, + {"prettier.config.js", "prettier"}, + {".stylelintrc", "stylelint"}, + {".stylelintrc.json", "stylelint"}, + {"stylelint.config.js", "stylelint"}, + {".flake8", "flake8"}, + {"setup.cfg", "flake8"}, // might contain flake8 config + {".pylintrc", "pylint"}, + {"pyproject.toml", "ruff"}, // might contain ruff config + {".rubocop.yml", "rubocop"}, + {"clippy.toml", "clippy"}, + {".clippy.toml", "clippy"}, + {".editorconfig", "editorconfig"}, + {"biome.json", "biome"}, + {"deno.json", "deno lint"}, + {".hadolint.yaml", "hadolint"}, + {".shellcheckrc", "shellcheck"}, + } + + seen := make(map[string]bool) + for _, lc := range lintConfigs { + path := filepath.Join(dir, lc.file) + if _, err := os.Stat(path); err == nil { + // Special case: setup.cfg / pyproject.toml may or may not contain lint config. + if lc.file == "setup.cfg" { + if data, err := os.ReadFile(path); err == nil { + if !strings.Contains(string(data), "[flake8]") { + continue + } + } + } + if lc.file == "pyproject.toml" { + if data, err := os.ReadFile(path); err == nil { + if !strings.Contains(string(data), "[tool.ruff]") && !strings.Contains(string(data), "ruff") { + continue + } + } + } + if !seen[lc.tool] { + seen[lc.tool] = true + tools = append(tools, lc.tool) + } + } + } + + // Check package.json for lint-related devDependencies. + pkgPath := filepath.Join(dir, "package.json") + if data, err := os.ReadFile(pkgPath); err == nil { + var pkg struct { + DevDependencies map[string]interface{} `json:"devDependencies"` + } + if err := json.Unmarshal(data, &pkg); err == nil { + jsDeps := []struct { + pkg string + tool string + }{ + {"eslint", "eslint"}, + {"prettier", "prettier"}, + {"stylelint", "stylelint"}, + {"biome", "biome"}, + {"@biomejs/biome", "biome"}, + } + for _, jd := range jsDeps { + if _, ok := pkg.DevDependencies[jd.pkg]; ok && !seen[jd.tool] { + seen[jd.tool] = true + tools = append(tools, jd.tool) + } + } + } + } + + return tools +} + +// detectCISystem identifies the CI/CD system in use and returns its name. +func detectCISystem(dir string) string { + ciSystems := []struct { + path string + isDir bool + system string + }{ + {filepath.Join(".github", "workflows"), true, "github-actions"}, + {".gitlab-ci.yml", false, "gitlab-ci"}, + {".circleci", true, "circleci"}, + {"Jenkinsfile", false, "jenkins"}, + {".travis.yml", false, "travis-ci"}, + {"azure-pipelines.yml", false, "azure-pipelines"}, + {"bitbucket-pipelines.yml", false, "bitbucket-pipelines"}, + {".drone.yml", false, "drone"}, + {".buildkite", true, "buildkite"}, + {"cloudbuild.yaml", false, "google-cloud-build"}, + {"cloudbuild.yml", false, "google-cloud-build"}, + {".tekton", true, "tekton"}, + } + + for _, ci := range ciSystems { + full := filepath.Join(dir, ci.path) + info, err := os.Stat(full) + if err != nil { + continue + } + if ci.isDir && info.IsDir() { + return ci.system + } + if !ci.isDir && !info.IsDir() { + return ci.system + } + } + + return "" +} + +// detectDocker checks for Dockerfile or docker-compose files. +func detectDocker(dir string) bool { + dockerFiles := []string{ + "Dockerfile", + "dockerfile", + "docker-compose.yml", + "docker-compose.yaml", + "compose.yml", + "compose.yaml", + ".dockerignore", + } + + for _, f := range dockerFiles { + path := filepath.Join(dir, f) + if _, err := os.Stat(path); err == nil { + return true + } + } + + return false +} + +// detectMonorepo checks for indicators of a monorepo structure. +func detectMonorepo(dir string) bool { + // Check for multiple go.mod files. + goModCount := 0 + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() && skipDirs[d.Name()] { + return filepath.SkipDir + } + if !d.IsDir() && d.Name() == "go.mod" { + goModCount++ + if goModCount > 1 { + return filepath.SkipAll + } + } + return nil + }) + if goModCount > 1 { + return true + } + + // Check for go.work file. + if _, err := os.Stat(filepath.Join(dir, "go.work")); err == nil { + return true + } + + // Check for packages/ or apps/ directory (JS monorepos). + monoDirs := []string{"packages", "apps", "modules", "services", "libs"} + for _, md := range monoDirs { + path := filepath.Join(dir, md) + if info, err := os.Stat(path); err == nil && info.IsDir() { + return true + } + } + + // Check for lerna.json, pnpm-workspace.yaml, or turbo.json. + monoFiles := []string{"lerna.json", "pnpm-workspace.yaml", "turbo.json", "nx.json"} + for _, mf := range monoFiles { + path := filepath.Join(dir, mf) + if _, err := os.Stat(path); err == nil { + return true + } + } + + // Check package.json for workspaces field. + pkgPath := filepath.Join(dir, "package.json") + if data, err := os.ReadFile(pkgPath); err == nil { + var pkg map[string]interface{} + if err := json.Unmarshal(data, &pkg); err == nil { + if _, ok := pkg["workspaces"]; ok { + return true + } + } + } + + return false +} + +// classifyProjectSize returns a size classification based on file count. +func classifyProjectSize(fileCount int) string { + switch { + case fileCount < 10: + return "tiny" + case fileCount < 100: + return "small" + case fileCount <= 1000: + return "medium" + default: + return "large" + } +} From a8d402b80d8ae88b931ffb2ad15f2e0fbcc9ca18 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 18:58:56 +0530 Subject: [PATCH 05/20] refactor(repomap): split depgraph.go into depgraph_build.go and depgraph_analysis.go --- internal/intelligence/repomap/depgraph.go | 1020 +---------------- .../intelligence/repomap/depgraph_analysis.go | 553 +++++++++ .../intelligence/repomap/depgraph_build.go | 479 ++++++++ 3 files changed, 1036 insertions(+), 1016 deletions(-) create mode 100644 internal/intelligence/repomap/depgraph_analysis.go create mode 100644 internal/intelligence/repomap/depgraph_build.go diff --git a/internal/intelligence/repomap/depgraph.go b/internal/intelligence/repomap/depgraph.go index 790dcf4d..b73da230 100644 --- a/internal/intelligence/repomap/depgraph.go +++ b/internal/intelligence/repomap/depgraph.go @@ -3,17 +3,15 @@ // (via package.json + import/require regexes). It computes topological // order, layers, cycles, hot paths, and renders the result as DOT, // Mermaid, or ASCII art for use in summaries and dashboards. +// +// This file holds the graph type, node mutators, renderers, and stats. The +// builders live in depgraph_build.go; the traversal algorithms live in +// depgraph_analysis.go. package repomap import ( - "bufio" - "encoding/json" "fmt" - "go/parser" - "go/token" - "os" "path/filepath" - "regexp" "sort" "strings" "sync" @@ -92,759 +90,6 @@ func (dg *DepGraph) AddEdge(edge DepEdge) { dg.Edges = append(dg.Edges, edge) } -// BuildFromGoMod reads go.mod and scans .go files to build the dependency graph. -func (dg *DepGraph) BuildFromGoMod(projectDir string) error { - dg.mu.Lock() - defer dg.mu.Unlock() - - goModPath := filepath.Join(projectDir, "go.mod") - modData, err := os.ReadFile(goModPath) - if err != nil { - return fmt.Errorf("depgraph: read go.mod: %w", err) - } - - moduleName := parseModuleName(string(modData)) - if moduleName == "" { - return fmt.Errorf("depgraph: cannot determine module name from go.mod") - } - dg.Root = moduleName - - // Parse external dependencies from go.mod require blocks. - externalDeps := parseGoModRequires(string(modData)) - - // Add external dependency nodes. - for _, dep := range externalDeps { - shortName := filepath.Base(dep) - dg.Nodes[dep] = &DepNode{ - ID: dep, - Name: shortName, - Type: "external", - ImportedBy: []string{}, - Imports: []string{}, - } - } - - // Scan all .go files to collect imports and build internal packages. - internalPkgs := make(map[string]*DepNode) - // pkgImports maps each internal package path to a set of import paths. - pkgImports := make(map[string]map[string]bool) - - fset := token.NewFileSet() - err = filepath.Walk(projectDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return nil - } - if info.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == ".git" || base == "testdata" { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") { - return nil - } - // Skip test files for dependency analysis. - if strings.HasSuffix(path, "_test.go") { - return nil - } - - f, parseErr := parser.ParseFile(fset, path, nil, parser.ImportsOnly) - if parseErr != nil { - return nil - } - - relDir, _ := filepath.Rel(projectDir, filepath.Dir(path)) - if relDir == "" || relDir == "." { - relDir = "" - } - var pkgPath string - if relDir == "" { - pkgPath = moduleName - } else { - pkgPath = moduleName + "/" + filepath.ToSlash(relDir) - } - - if _, ok := internalPkgs[pkgPath]; !ok { - shortName := filepath.Base(pkgPath) - if pkgPath == moduleName { - shortName = filepath.Base(moduleName) - } - internalPkgs[pkgPath] = &DepNode{ - ID: pkgPath, - Name: shortName, - Type: "internal", - FileCount: 0, - LOC: 0, - ImportedBy: []string{}, - Imports: []string{}, - } - pkgImports[pkgPath] = make(map[string]bool) - } - - internalPkgs[pkgPath].FileCount++ - - // Count LOC. - loc := countFileLOC(path) - internalPkgs[pkgPath].LOC += loc - - // Collect imports. - for _, imp := range f.Imports { - impPath := strings.Trim(imp.Path.Value, `"`) - pkgImports[pkgPath][impPath] = true - } - - return nil - }) - if err != nil { - return fmt.Errorf("depgraph: walk project: %w", err) - } - - // Add internal package nodes. - for id, node := range internalPkgs { - dg.Nodes[id] = node - } - - // Process imports and create edges. - for pkgPath, imports := range pkgImports { - for imp := range imports { - impType := classifyImport(imp, moduleName, externalDeps) - - // Ensure stdlib nodes exist. - if impType == "stdlib" { - if _, ok := dg.Nodes[imp]; !ok { - dg.Nodes[imp] = &DepNode{ - ID: imp, - Name: filepath.Base(imp), - Type: "stdlib", - ImportedBy: []string{}, - Imports: []string{}, - } - } - } - - // Record the import relationship. - if node, ok := dg.Nodes[pkgPath]; ok { - node.Imports = appendUniqueStr(node.Imports, imp) - } - if node, ok := dg.Nodes[imp]; ok { - node.ImportedBy = appendUniqueStr(node.ImportedBy, pkgPath) - } - - // Add edge. - found := false - for i, e := range dg.Edges { - if e.From == pkgPath && e.To == imp { - dg.Edges[i].Weight++ - found = true - break - } - } - if !found { - dg.Edges = append(dg.Edges, DepEdge{ - From: pkgPath, - To: imp, - Weight: 1, - }) - } - } - } - - return nil -} - -// BuildFromPackageJSON reads package.json and scans JS/TS files to build the -// dependency graph. -func (dg *DepGraph) BuildFromPackageJSON(projectDir string) error { - dg.mu.Lock() - defer dg.mu.Unlock() - - pkgJSONPath := filepath.Join(projectDir, "package.json") - data, err := os.ReadFile(pkgJSONPath) - if err != nil { - return fmt.Errorf("depgraph: read package.json: %w", err) - } - - var pkgJSON struct { - Name string `json:"name"` - Dependencies map[string]string `json:"dependencies"` - DevDependencies map[string]string `json:"devDependencies"` - } - if unmarshalErr := json.Unmarshal(data, &pkgJSON); unmarshalErr != nil { - return fmt.Errorf("depgraph: parse package.json: %w", unmarshalErr) - } - - dg.Root = pkgJSON.Name - - // Add the root package node. - dg.Nodes[pkgJSON.Name] = &DepNode{ - ID: pkgJSON.Name, - Name: pkgJSON.Name, - Type: "internal", - ImportedBy: []string{}, - Imports: []string{}, - } - - // Collect all declared dependencies. - allDeps := make(map[string]bool) - for dep := range pkgJSON.Dependencies { - allDeps[dep] = true - dg.Nodes[dep] = &DepNode{ - ID: dep, - Name: dep, - Type: "external", - ImportedBy: []string{}, - Imports: []string{}, - } - } - for dep := range pkgJSON.DevDependencies { - allDeps[dep] = true - if _, ok := dg.Nodes[dep]; !ok { - dg.Nodes[dep] = &DepNode{ - ID: dep, - Name: dep, - Type: "external", - ImportedBy: []string{}, - Imports: []string{}, - } - } - } - - // Scan JS/TS files for imports. - jsImportRe := regexp.MustCompile(`(?:import\s+.*?\s+from\s+['"]([^'"]+)['"]|require\s*\(\s*['"]([^'"]+)['"]\s*\))`) - - // Internal modules map (relative imports). - internalModules := make(map[string]*DepNode) - - err = filepath.Walk(projectDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return nil - } - if info.IsDir() { - base := filepath.Base(path) - if base == "node_modules" || base == ".git" || base == "dist" || base == "build" { - return filepath.SkipDir - } - return nil - } - ext := filepath.Ext(path) - if ext != ".js" && ext != ".ts" && ext != ".jsx" && ext != ".tsx" { - return nil - } - - relPath, _ := filepath.Rel(projectDir, path) - relPath = filepath.ToSlash(relPath) - - // Determine the "module" path for this file. - modPath := pkgJSON.Name + "/" + relPath - - if _, ok := internalModules[modPath]; !ok { - internalModules[modPath] = &DepNode{ - ID: modPath, - Name: filepath.Base(relPath), - Type: "internal", - FileCount: 1, - LOC: 0, - ImportedBy: []string{}, - Imports: []string{}, - } - } - internalModules[modPath].LOC += countFileLOC(path) - - // Read file and find imports. - content, readErr := os.ReadFile(path) - if readErr != nil { - return nil - } - - matches := jsImportRe.FindAllStringSubmatch(string(content), -1) - for _, match := range matches { - imported := match[1] - if imported == "" { - imported = match[2] - } - if imported == "" { - continue - } - - var targetID string - if strings.HasPrefix(imported, ".") { - // Relative import => internal. - dir := filepath.Dir(relPath) - resolved := filepath.ToSlash(filepath.Join(dir, imported)) - targetID = pkgJSON.Name + "/" + resolved - if _, ok := internalModules[targetID]; !ok { - internalModules[targetID] = &DepNode{ - ID: targetID, - Name: filepath.Base(imported), - Type: "internal", - ImportedBy: []string{}, - Imports: []string{}, - } - } - } else { - // Package import => external. - // Extract package name (handle scoped packages). - pkgName := imported - if strings.HasPrefix(imported, "@") { - parts := strings.SplitN(imported, "/", 3) - if len(parts) >= 2 { - pkgName = parts[0] + "/" + parts[1] - } - } else { - parts := strings.SplitN(imported, "/", 2) - pkgName = parts[0] - } - targetID = pkgName - if _, ok := dg.Nodes[targetID]; !ok { - nodeType := "external" - if isNodeBuiltin(pkgName) { - nodeType = "stdlib" - } - dg.Nodes[targetID] = &DepNode{ - ID: targetID, - Name: pkgName, - Type: nodeType, - ImportedBy: []string{}, - Imports: []string{}, - } - } - } - - internalModules[modPath].Imports = appendUniqueStr(internalModules[modPath].Imports, targetID) - if node, ok := dg.Nodes[targetID]; ok { - node.ImportedBy = appendUniqueStr(node.ImportedBy, modPath) - } else if mod, ok := internalModules[targetID]; ok { - mod.ImportedBy = appendUniqueStr(mod.ImportedBy, modPath) - } - - // Add edge. - found := false - for i, e := range dg.Edges { - if e.From == modPath && e.To == targetID { - dg.Edges[i].Weight++ - found = true - break - } - } - if !found { - dg.Edges = append(dg.Edges, DepEdge{ - From: modPath, - To: targetID, - Weight: 1, - }) - } - } - - return nil - }) - if err != nil { - return fmt.Errorf("depgraph: walk project: %w", err) - } - - // Merge internal modules into the graph. - for id, node := range internalModules { - dg.Nodes[id] = node - } - - return nil -} - -// TopologicalSort returns packages in dependency order (leaves first). -// Packages with no dependencies appear first. -func (dg *DepGraph) TopologicalSort() []string { - dg.mu.RLock() - defer dg.mu.RUnlock() - - // Build adjacency list and in-degree count. - inDegree := make(map[string]int) - adj := make(map[string][]string) - - for id := range dg.Nodes { - inDegree[id] = 0 - } - for _, edge := range dg.Edges { - if _, ok := dg.Nodes[edge.From]; !ok { - continue - } - if _, ok := dg.Nodes[edge.To]; !ok { - continue - } - adj[edge.From] = append(adj[edge.From], edge.To) - inDegree[edge.To]++ - } - - // Kahn's algorithm. - var queue []string - for id, deg := range inDegree { - if deg == 0 { - queue = append(queue, id) - } - } - sort.Strings(queue) // deterministic order - - for len(queue) > 0 { - sort.Strings(queue) - node := queue[0] - queue = queue[1:] - - for _, neighbor := range adj[node] { - inDegree[neighbor]-- - if inDegree[neighbor] == 0 { - queue = append(queue, neighbor) - } - } - } - // For "leaves first" (packages with no dependencies), we reverse the edge direction. - - // Re-do with reversed edges: nodes that IMPORT nothing come first. - outDegree := make(map[string]int) - revAdj := make(map[string][]string) - for id := range dg.Nodes { - outDegree[id] = 0 - } - for _, edge := range dg.Edges { - if _, ok := dg.Nodes[edge.From]; !ok { - continue - } - if _, ok := dg.Nodes[edge.To]; !ok { - continue - } - outDegree[edge.From]++ - revAdj[edge.To] = append(revAdj[edge.To], edge.From) - } - - // Collect leaves (nodes with no outgoing edges = no imports). - var leafQueue []string - for id, deg := range outDegree { - if deg == 0 { - leafQueue = append(leafQueue, id) - } - } - sort.Strings(leafQueue) - - visited := make(map[string]bool) - var sorted []string - for len(leafQueue) > 0 { - sort.Strings(leafQueue) - node := leafQueue[0] - leafQueue = leafQueue[1:] - if visited[node] { - continue - } - visited[node] = true - sorted = append(sorted, node) - - for _, parent := range revAdj[node] { - // Check if all of parent's dependencies are visited. - allVisited := true - for _, e := range dg.Edges { - if e.From == parent { - if _, ok := dg.Nodes[e.To]; ok { - if !visited[e.To] { - allVisited = false - break - } - } - } - } - if allVisited && !visited[parent] { - leafQueue = append(leafQueue, parent) - } - } - } - - // Add any remaining nodes (part of cycles) at the end. - for id := range dg.Nodes { - if !visited[id] { - sorted = append(sorted, id) - } - } - - return sorted -} - -// FindCycles detects circular dependencies and returns all cycles found. -func (dg *DepGraph) FindCycles() [][]string { - dg.mu.RLock() - defer dg.mu.RUnlock() - - // Build adjacency list. - adj := make(map[string][]string) - for _, edge := range dg.Edges { - if _, ok := dg.Nodes[edge.From]; !ok { - continue - } - if _, ok := dg.Nodes[edge.To]; !ok { - continue - } - adj[edge.From] = append(adj[edge.From], edge.To) - } - - // Sort adjacency lists for determinism. - for k := range adj { - sort.Strings(adj[k]) - } - - // Johnson's algorithm simplified: DFS-based cycle detection. - var cycles [][]string - visited := make(map[string]bool) - onStack := make(map[string]bool) - var path []string - - var dfs func(node string) - dfs = func(node string) { - visited[node] = true - onStack[node] = true - path = append(path, node) - - for _, next := range adj[node] { - if !visited[next] { - dfs(next) - } else if onStack[next] { - // Found a cycle: extract it from path. - cycleStart := -1 - for i, p := range path { - if p == next { - cycleStart = i - break - } - } - if cycleStart >= 0 { - cycle := make([]string, len(path)-cycleStart) - copy(cycle, path[cycleStart:]) - // Check for duplicates. - if !containsCycle(cycles, cycle) { - cycles = append(cycles, cycle) - } - } - } - } - - path = path[:len(path)-1] - onStack[node] = false - } - - // Sort nodes for deterministic order. - nodeIDs := make([]string, 0, len(dg.Nodes)) - for id := range dg.Nodes { - nodeIDs = append(nodeIDs, id) - } - sort.Strings(nodeIDs) - - for _, id := range nodeIDs { - if !visited[id] { - dfs(id) - } - } - - return cycles -} - -// Layers groups packages into layers based on dependency depth. -// Layer 0 contains packages with no dependencies, layer 1 depends only on layer 0, etc. -func (dg *DepGraph) Layers() [][]string { - dg.mu.RLock() - defer dg.mu.RUnlock() - - if len(dg.Nodes) == 0 { - return nil - } - - // Build adjacency (from -> deps). - deps := make(map[string][]string) - for _, edge := range dg.Edges { - if _, ok := dg.Nodes[edge.From]; !ok { - continue - } - if _, ok := dg.Nodes[edge.To]; !ok { - continue - } - deps[edge.From] = append(deps[edge.From], edge.To) - } - - layerOf := make(map[string]int) - var computeLayer func(id string, visiting map[string]bool) int - computeLayer = func(id string, visiting map[string]bool) int { - if l, ok := layerOf[id]; ok { - return l - } - if visiting[id] { - // Cycle detected, assign 0 to break it. - return 0 - } - visiting[id] = true - - maxDep := -1 - for _, dep := range deps[id] { - l := computeLayer(dep, visiting) - if l > maxDep { - maxDep = l - } - } - layerOf[id] = maxDep + 1 - delete(visiting, id) - return maxDep + 1 - } - - for id := range dg.Nodes { - if _, ok := layerOf[id]; !ok { - computeLayer(id, make(map[string]bool)) - } - } - - // Group by layer. - maxLayer := 0 - for _, l := range layerOf { - if l > maxLayer { - maxLayer = l - } - } - - layers := make([][]string, maxLayer+1) - for id, l := range layerOf { - layers[l] = append(layers[l], id) - } - - // Sort within each layer for determinism. - for i := range layers { - sort.Strings(layers[i]) - } - - return layers -} - -// HotPaths finds the most-depended-on paths using a PageRank-like importance scoring. -// Returns paths sorted by importance (most critical first). -func (dg *DepGraph) HotPaths() [][]string { - dg.mu.RLock() - defer dg.mu.RUnlock() - - if len(dg.Nodes) == 0 { - return nil - } - - // Compute PageRank-like scores. - scores := make(map[string]float64) - n := float64(len(dg.Nodes)) - for id := range dg.Nodes { - scores[id] = 1.0 / n - } - - // Build inbound edges. - inbound := make(map[string][]string) - outCount := make(map[string]int) - for _, edge := range dg.Edges { - if _, ok := dg.Nodes[edge.From]; !ok { - continue - } - if _, ok := dg.Nodes[edge.To]; !ok { - continue - } - inbound[edge.To] = append(inbound[edge.To], edge.From) - outCount[edge.From]++ - } - - damping := 0.85 - for iter := 0; iter < 20; iter++ { - newScores := make(map[string]float64) - for id := range dg.Nodes { - sum := 0.0 - for _, src := range inbound[id] { - if outCount[src] > 0 { - sum += scores[src] / float64(outCount[src]) - } - } - newScores[id] = (1.0-damping)/n + damping*sum - } - scores = newScores - } - - // Sort nodes by score descending. - type scoredNode struct { - id string - score float64 - } - var ranked []scoredNode - for id, score := range scores { - ranked = append(ranked, scoredNode{id, score}) - } - sort.Slice(ranked, func(i, j int) bool { - if ranked[i].score == ranked[j].score { - return ranked[i].id < ranked[j].id - } - return ranked[i].score > ranked[j].score - }) - - // Build paths from top-ranked nodes following their heaviest dependency chains. - adj := make(map[string][]string) - edgeWeight := make(map[string]int) // "from->to" => weight - for _, edge := range dg.Edges { - if _, ok := dg.Nodes[edge.From]; !ok { - continue - } - if _, ok := dg.Nodes[edge.To]; !ok { - continue - } - adj[edge.From] = append(adj[edge.From], edge.To) - key := edge.From + "->" + edge.To - edgeWeight[key] = edge.Weight - } - - var paths [][]string - used := make(map[string]bool) - - for _, rn := range ranked { - if used[rn.id] { - continue - } - if len(paths) >= 5 { - break - } - - // Follow the heaviest outgoing chain. - path := []string{rn.id} - current := rn.id - visited := map[string]bool{current: true} - - for { - neighbors := adj[current] - if len(neighbors) == 0 { - break - } - // Pick the heaviest edge. - bestNeighbor := "" - bestWeight := 0 - for _, nb := range neighbors { - if visited[nb] { - continue - } - key := current + "->" + nb - w := edgeWeight[key] - if w > bestWeight || (w == bestWeight && nb < bestNeighbor) { - bestWeight = w - bestNeighbor = nb - } - } - if bestNeighbor == "" { - break - } - path = append(path, bestNeighbor) - visited[bestNeighbor] = true - current = bestNeighbor - } - - if len(path) > 1 { - paths = append(paths, path) - for _, p := range path { - used[p] = true - } - } - } - - return paths -} - // RenderDOT generates a Graphviz DOT format representation of the graph. func (dg *DepGraph) RenderDOT() string { dg.mu.RLock() @@ -1063,139 +308,6 @@ func (dg *DepGraph) Stats() GraphStats { return stats } -// --- Internal helpers --- - -// layersUnlocked computes layers without holding the lock (caller must hold RLock). -func (dg *DepGraph) layersUnlocked() [][]string { - if len(dg.Nodes) == 0 { - return nil - } - - deps := make(map[string][]string) - for _, edge := range dg.Edges { - if _, ok := dg.Nodes[edge.From]; !ok { - continue - } - if _, ok := dg.Nodes[edge.To]; !ok { - continue - } - deps[edge.From] = append(deps[edge.From], edge.To) - } - - layerOf := make(map[string]int) - var computeLayer func(id string, visiting map[string]bool) int - computeLayer = func(id string, visiting map[string]bool) int { - if l, ok := layerOf[id]; ok { - return l - } - if visiting[id] { - return 0 - } - visiting[id] = true - - maxDep := -1 - for _, dep := range deps[id] { - l := computeLayer(dep, visiting) - if l > maxDep { - maxDep = l - } - } - layerOf[id] = maxDep + 1 - delete(visiting, id) - return maxDep + 1 - } - - for id := range dg.Nodes { - if _, ok := layerOf[id]; !ok { - computeLayer(id, make(map[string]bool)) - } - } - - maxLayer := 0 - for _, l := range layerOf { - if l > maxLayer { - maxLayer = l - } - } - - layers := make([][]string, maxLayer+1) - for id, l := range layerOf { - layers[l] = append(layers[l], id) - } - for i := range layers { - sort.Strings(layers[i]) - } - - return layers -} - -// findCyclesUnlocked detects cycles without holding the lock (caller must hold RLock). -func (dg *DepGraph) findCyclesUnlocked() [][]string { - adj := make(map[string][]string) - for _, edge := range dg.Edges { - if _, ok := dg.Nodes[edge.From]; !ok { - continue - } - if _, ok := dg.Nodes[edge.To]; !ok { - continue - } - adj[edge.From] = append(adj[edge.From], edge.To) - } - for k := range adj { - sort.Strings(adj[k]) - } - - var cycles [][]string - visited := make(map[string]bool) - onStack := make(map[string]bool) - var path []string - - var dfs func(node string) - dfs = func(node string) { - visited[node] = true - onStack[node] = true - path = append(path, node) - - for _, next := range adj[node] { - if !visited[next] { - dfs(next) - } else if onStack[next] { - cycleStart := -1 - for i, p := range path { - if p == next { - cycleStart = i - break - } - } - if cycleStart >= 0 { - cycle := make([]string, len(path)-cycleStart) - copy(cycle, path[cycleStart:]) - if !containsCycle(cycles, cycle) { - cycles = append(cycles, cycle) - } - } - } - } - - path = path[:len(path)-1] - onStack[node] = false - } - - nodeIDs := make([]string, 0, len(dg.Nodes)) - for id := range dg.Nodes { - nodeIDs = append(nodeIDs, id) - } - sort.Strings(nodeIDs) - - for _, id := range nodeIDs { - if !visited[id] { - dfs(id) - } - } - - return cycles -} - // shortName returns the short display name for a node. func (dg *DepGraph) shortName(id string) string { if node, ok := dg.Nodes[id]; ok && node.Name != "" { @@ -1203,127 +315,3 @@ func (dg *DepGraph) shortName(id string) string { } return filepath.Base(id) } - -// parseModuleName extracts the module name from go.mod content. -func parseModuleName(content string) string { - scanner := bufio.NewScanner(strings.NewReader(content)) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if strings.HasPrefix(line, "module ") { - return strings.TrimSpace(strings.TrimPrefix(line, "module")) - } - } - return "" -} - -// parseGoModRequires extracts dependency paths from go.mod require blocks. -func parseGoModRequires(content string) []string { - var deps []string - inRequire := false - scanner := bufio.NewScanner(strings.NewReader(content)) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if strings.HasPrefix(line, "require (") || line == "require (" { - inRequire = true - continue - } - if inRequire { - if line == ")" { - inRequire = false - continue - } - parts := strings.Fields(line) - if len(parts) >= 2 { - deps = append(deps, parts[0]) - } - } - // Single-line require. - if strings.HasPrefix(line, "require ") && !strings.Contains(line, "(") { - parts := strings.Fields(line) - if len(parts) >= 3 { - deps = append(deps, parts[1]) - } - } - } - return deps -} - -// classifyImport determines the type of an import path. -func classifyImport(importPath, moduleName string, externalDeps []string) string { - // Internal: starts with module name. - if strings.HasPrefix(importPath, moduleName) { - return "internal" - } - - // External: matches a known dependency. - for _, dep := range externalDeps { - if strings.HasPrefix(importPath, dep) { - return "external" - } - } - - // If it contains a dot in the first path component, it's likely external. - firstComponent := strings.SplitN(importPath, "/", 2)[0] - if strings.Contains(firstComponent, ".") { - return "external" - } - - return "stdlib" -} - -// countFileLOC counts lines of code in a file (non-blank lines). -func countFileLOC(path string) int { - f, err := os.Open(path) - if err != nil { - return 0 - } - defer func() { _ = f.Close() }() - - count := 0 - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line != "" { - count++ - } - } - return count -} - -// appendUniqueStr is defined in callgraph.go — reused here. - -// containsCycle checks if cycles already contains an equivalent cycle. -func containsCycle(cycles [][]string, cycle []string) bool { - for _, existing := range cycles { - if len(existing) == len(cycle) { - match := true - for i := range existing { - if existing[i] != cycle[i] { - match = false - break - } - } - if match { - return true - } - } - } - return false -} - -// isNodeBuiltin checks if a package name is a Node.js built-in module. -func isNodeBuiltin(name string) bool { - builtins := map[string]bool{ - "fs": true, "path": true, "os": true, "http": true, "https": true, - "crypto": true, "util": true, "events": true, "stream": true, - "child_process": true, "net": true, "url": true, "querystring": true, - "buffer": true, "assert": true, "cluster": true, "dns": true, - "readline": true, "tls": true, "zlib": true, "vm": true, - "process": true, "module": true, "console": true, "timers": true, - } - // Also handle "node:" prefix. - if strings.HasPrefix(name, "node:") { - return true - } - return builtins[name] -} diff --git a/internal/intelligence/repomap/depgraph_analysis.go b/internal/intelligence/repomap/depgraph_analysis.go new file mode 100644 index 00000000..73bc0c9c --- /dev/null +++ b/internal/intelligence/repomap/depgraph_analysis.go @@ -0,0 +1,553 @@ +package repomap + +import "sort" + +// This file holds the graph traversal algorithms (topological sort, cycle +// detection, layering, hot-path scoring) and their unlocked helpers. The graph +// type, mutators, and renderers live in depgraph.go; the builders live in +// depgraph_build.go. + +// TopologicalSort returns packages in dependency order (leaves first). +// Packages with no dependencies appear first. +func (dg *DepGraph) TopologicalSort() []string { + dg.mu.RLock() + defer dg.mu.RUnlock() + + // Build adjacency list and in-degree count. + inDegree := make(map[string]int) + adj := make(map[string][]string) + + for id := range dg.Nodes { + inDegree[id] = 0 + } + for _, edge := range dg.Edges { + if _, ok := dg.Nodes[edge.From]; !ok { + continue + } + if _, ok := dg.Nodes[edge.To]; !ok { + continue + } + adj[edge.From] = append(adj[edge.From], edge.To) + inDegree[edge.To]++ + } + + // Kahn's algorithm. + var queue []string + for id, deg := range inDegree { + if deg == 0 { + queue = append(queue, id) + } + } + sort.Strings(queue) // deterministic order + + for len(queue) > 0 { + sort.Strings(queue) + node := queue[0] + queue = queue[1:] + + for _, neighbor := range adj[node] { + inDegree[neighbor]-- + if inDegree[neighbor] == 0 { + queue = append(queue, neighbor) + } + } + } + // For "leaves first" (packages with no dependencies), we reverse the edge direction. + + // Re-do with reversed edges: nodes that IMPORT nothing come first. + outDegree := make(map[string]int) + revAdj := make(map[string][]string) + for id := range dg.Nodes { + outDegree[id] = 0 + } + for _, edge := range dg.Edges { + if _, ok := dg.Nodes[edge.From]; !ok { + continue + } + if _, ok := dg.Nodes[edge.To]; !ok { + continue + } + outDegree[edge.From]++ + revAdj[edge.To] = append(revAdj[edge.To], edge.From) + } + + // Collect leaves (nodes with no outgoing edges = no imports). + var leafQueue []string + for id, deg := range outDegree { + if deg == 0 { + leafQueue = append(leafQueue, id) + } + } + sort.Strings(leafQueue) + + visited := make(map[string]bool) + var sorted []string + for len(leafQueue) > 0 { + sort.Strings(leafQueue) + node := leafQueue[0] + leafQueue = leafQueue[1:] + if visited[node] { + continue + } + visited[node] = true + sorted = append(sorted, node) + + for _, parent := range revAdj[node] { + // Check if all of parent's dependencies are visited. + allVisited := true + for _, e := range dg.Edges { + if e.From == parent { + if _, ok := dg.Nodes[e.To]; ok { + if !visited[e.To] { + allVisited = false + break + } + } + } + } + if allVisited && !visited[parent] { + leafQueue = append(leafQueue, parent) + } + } + } + + // Add any remaining nodes (part of cycles) at the end. + for id := range dg.Nodes { + if !visited[id] { + sorted = append(sorted, id) + } + } + + return sorted +} + +// FindCycles detects circular dependencies and returns all cycles found. +func (dg *DepGraph) FindCycles() [][]string { + dg.mu.RLock() + defer dg.mu.RUnlock() + + // Build adjacency list. + adj := make(map[string][]string) + for _, edge := range dg.Edges { + if _, ok := dg.Nodes[edge.From]; !ok { + continue + } + if _, ok := dg.Nodes[edge.To]; !ok { + continue + } + adj[edge.From] = append(adj[edge.From], edge.To) + } + + // Sort adjacency lists for determinism. + for k := range adj { + sort.Strings(adj[k]) + } + + // Johnson's algorithm simplified: DFS-based cycle detection. + var cycles [][]string + visited := make(map[string]bool) + onStack := make(map[string]bool) + var path []string + + var dfs func(node string) + dfs = func(node string) { + visited[node] = true + onStack[node] = true + path = append(path, node) + + for _, next := range adj[node] { + if !visited[next] { + dfs(next) + } else if onStack[next] { + // Found a cycle: extract it from path. + cycleStart := -1 + for i, p := range path { + if p == next { + cycleStart = i + break + } + } + if cycleStart >= 0 { + cycle := make([]string, len(path)-cycleStart) + copy(cycle, path[cycleStart:]) + // Check for duplicates. + if !containsCycle(cycles, cycle) { + cycles = append(cycles, cycle) + } + } + } + } + + path = path[:len(path)-1] + onStack[node] = false + } + + // Sort nodes for deterministic order. + nodeIDs := make([]string, 0, len(dg.Nodes)) + for id := range dg.Nodes { + nodeIDs = append(nodeIDs, id) + } + sort.Strings(nodeIDs) + + for _, id := range nodeIDs { + if !visited[id] { + dfs(id) + } + } + + return cycles +} + +// Layers groups packages into layers based on dependency depth. +// Layer 0 contains packages with no dependencies, layer 1 depends only on layer 0, etc. +func (dg *DepGraph) Layers() [][]string { + dg.mu.RLock() + defer dg.mu.RUnlock() + + if len(dg.Nodes) == 0 { + return nil + } + + // Build adjacency (from -> deps). + deps := make(map[string][]string) + for _, edge := range dg.Edges { + if _, ok := dg.Nodes[edge.From]; !ok { + continue + } + if _, ok := dg.Nodes[edge.To]; !ok { + continue + } + deps[edge.From] = append(deps[edge.From], edge.To) + } + + layerOf := make(map[string]int) + var computeLayer func(id string, visiting map[string]bool) int + computeLayer = func(id string, visiting map[string]bool) int { + if l, ok := layerOf[id]; ok { + return l + } + if visiting[id] { + // Cycle detected, assign 0 to break it. + return 0 + } + visiting[id] = true + + maxDep := -1 + for _, dep := range deps[id] { + l := computeLayer(dep, visiting) + if l > maxDep { + maxDep = l + } + } + layerOf[id] = maxDep + 1 + delete(visiting, id) + return maxDep + 1 + } + + for id := range dg.Nodes { + if _, ok := layerOf[id]; !ok { + computeLayer(id, make(map[string]bool)) + } + } + + // Group by layer. + maxLayer := 0 + for _, l := range layerOf { + if l > maxLayer { + maxLayer = l + } + } + + layers := make([][]string, maxLayer+1) + for id, l := range layerOf { + layers[l] = append(layers[l], id) + } + + // Sort within each layer for determinism. + for i := range layers { + sort.Strings(layers[i]) + } + + return layers +} + +// HotPaths finds the most-depended-on paths using a PageRank-like importance scoring. +// Returns paths sorted by importance (most critical first). +func (dg *DepGraph) HotPaths() [][]string { + dg.mu.RLock() + defer dg.mu.RUnlock() + + if len(dg.Nodes) == 0 { + return nil + } + + // Compute PageRank-like scores. + scores := make(map[string]float64) + n := float64(len(dg.Nodes)) + for id := range dg.Nodes { + scores[id] = 1.0 / n + } + + // Build inbound edges. + inbound := make(map[string][]string) + outCount := make(map[string]int) + for _, edge := range dg.Edges { + if _, ok := dg.Nodes[edge.From]; !ok { + continue + } + if _, ok := dg.Nodes[edge.To]; !ok { + continue + } + inbound[edge.To] = append(inbound[edge.To], edge.From) + outCount[edge.From]++ + } + + damping := 0.85 + for iter := 0; iter < 20; iter++ { + newScores := make(map[string]float64) + for id := range dg.Nodes { + sum := 0.0 + for _, src := range inbound[id] { + if outCount[src] > 0 { + sum += scores[src] / float64(outCount[src]) + } + } + newScores[id] = (1.0-damping)/n + damping*sum + } + scores = newScores + } + + // Sort nodes by score descending. + type scoredNode struct { + id string + score float64 + } + var ranked []scoredNode + for id, score := range scores { + ranked = append(ranked, scoredNode{id, score}) + } + sort.Slice(ranked, func(i, j int) bool { + if ranked[i].score == ranked[j].score { + return ranked[i].id < ranked[j].id + } + return ranked[i].score > ranked[j].score + }) + + // Build paths from top-ranked nodes following their heaviest dependency chains. + adj := make(map[string][]string) + edgeWeight := make(map[string]int) // "from->to" => weight + for _, edge := range dg.Edges { + if _, ok := dg.Nodes[edge.From]; !ok { + continue + } + if _, ok := dg.Nodes[edge.To]; !ok { + continue + } + adj[edge.From] = append(adj[edge.From], edge.To) + key := edge.From + "->" + edge.To + edgeWeight[key] = edge.Weight + } + + var paths [][]string + used := make(map[string]bool) + + for _, rn := range ranked { + if used[rn.id] { + continue + } + if len(paths) >= 5 { + break + } + + // Follow the heaviest outgoing chain. + path := []string{rn.id} + current := rn.id + visited := map[string]bool{current: true} + + for { + neighbors := adj[current] + if len(neighbors) == 0 { + break + } + // Pick the heaviest edge. + bestNeighbor := "" + bestWeight := 0 + for _, nb := range neighbors { + if visited[nb] { + continue + } + key := current + "->" + nb + w := edgeWeight[key] + if w > bestWeight || (w == bestWeight && nb < bestNeighbor) { + bestWeight = w + bestNeighbor = nb + } + } + if bestNeighbor == "" { + break + } + path = append(path, bestNeighbor) + visited[bestNeighbor] = true + current = bestNeighbor + } + + if len(path) > 1 { + paths = append(paths, path) + for _, p := range path { + used[p] = true + } + } + } + + return paths +} + +// layersUnlocked computes layers without holding the lock (caller must hold RLock). +func (dg *DepGraph) layersUnlocked() [][]string { + if len(dg.Nodes) == 0 { + return nil + } + + deps := make(map[string][]string) + for _, edge := range dg.Edges { + if _, ok := dg.Nodes[edge.From]; !ok { + continue + } + if _, ok := dg.Nodes[edge.To]; !ok { + continue + } + deps[edge.From] = append(deps[edge.From], edge.To) + } + + layerOf := make(map[string]int) + var computeLayer func(id string, visiting map[string]bool) int + computeLayer = func(id string, visiting map[string]bool) int { + if l, ok := layerOf[id]; ok { + return l + } + if visiting[id] { + return 0 + } + visiting[id] = true + + maxDep := -1 + for _, dep := range deps[id] { + l := computeLayer(dep, visiting) + if l > maxDep { + maxDep = l + } + } + layerOf[id] = maxDep + 1 + delete(visiting, id) + return maxDep + 1 + } + + for id := range dg.Nodes { + if _, ok := layerOf[id]; !ok { + computeLayer(id, make(map[string]bool)) + } + } + + maxLayer := 0 + for _, l := range layerOf { + if l > maxLayer { + maxLayer = l + } + } + + layers := make([][]string, maxLayer+1) + for id, l := range layerOf { + layers[l] = append(layers[l], id) + } + for i := range layers { + sort.Strings(layers[i]) + } + + return layers +} + +// findCyclesUnlocked detects cycles without holding the lock (caller must hold RLock). +func (dg *DepGraph) findCyclesUnlocked() [][]string { + adj := make(map[string][]string) + for _, edge := range dg.Edges { + if _, ok := dg.Nodes[edge.From]; !ok { + continue + } + if _, ok := dg.Nodes[edge.To]; !ok { + continue + } + adj[edge.From] = append(adj[edge.From], edge.To) + } + for k := range adj { + sort.Strings(adj[k]) + } + + var cycles [][]string + visited := make(map[string]bool) + onStack := make(map[string]bool) + var path []string + + var dfs func(node string) + dfs = func(node string) { + visited[node] = true + onStack[node] = true + path = append(path, node) + + for _, next := range adj[node] { + if !visited[next] { + dfs(next) + } else if onStack[next] { + cycleStart := -1 + for i, p := range path { + if p == next { + cycleStart = i + break + } + } + if cycleStart >= 0 { + cycle := make([]string, len(path)-cycleStart) + copy(cycle, path[cycleStart:]) + if !containsCycle(cycles, cycle) { + cycles = append(cycles, cycle) + } + } + } + } + + path = path[:len(path)-1] + onStack[node] = false + } + + nodeIDs := make([]string, 0, len(dg.Nodes)) + for id := range dg.Nodes { + nodeIDs = append(nodeIDs, id) + } + sort.Strings(nodeIDs) + + for _, id := range nodeIDs { + if !visited[id] { + dfs(id) + } + } + + return cycles +} + +// containsCycle checks if cycles already contains an equivalent cycle. +func containsCycle(cycles [][]string, cycle []string) bool { + for _, existing := range cycles { + if len(existing) == len(cycle) { + match := true + for i := range existing { + if existing[i] != cycle[i] { + match = false + break + } + } + if match { + return true + } + } + } + return false +} diff --git a/internal/intelligence/repomap/depgraph_build.go b/internal/intelligence/repomap/depgraph_build.go new file mode 100644 index 00000000..132d6957 --- /dev/null +++ b/internal/intelligence/repomap/depgraph_build.go @@ -0,0 +1,479 @@ +package repomap + +import ( + "bufio" + "encoding/json" + "fmt" + "go/parser" + "go/token" + "os" + "path/filepath" + "regexp" + "strings" +) + +// This file holds the graph builders (Go via go.mod + go/parser, JS/TS via +// package.json + import regexes) and their parsing helpers. The graph type and +// renderers live in depgraph.go; traversal algorithms live in +// depgraph_analysis.go. + +// BuildFromGoMod reads go.mod and scans .go files to build the dependency graph. +func (dg *DepGraph) BuildFromGoMod(projectDir string) error { + dg.mu.Lock() + defer dg.mu.Unlock() + + goModPath := filepath.Join(projectDir, "go.mod") + modData, err := os.ReadFile(goModPath) + if err != nil { + return fmt.Errorf("depgraph: read go.mod: %w", err) + } + + moduleName := parseModuleName(string(modData)) + if moduleName == "" { + return fmt.Errorf("depgraph: cannot determine module name from go.mod") + } + dg.Root = moduleName + + // Parse external dependencies from go.mod require blocks. + externalDeps := parseGoModRequires(string(modData)) + + // Add external dependency nodes. + for _, dep := range externalDeps { + shortName := filepath.Base(dep) + dg.Nodes[dep] = &DepNode{ + ID: dep, + Name: shortName, + Type: "external", + ImportedBy: []string{}, + Imports: []string{}, + } + } + + // Scan all .go files to collect imports and build internal packages. + internalPkgs := make(map[string]*DepNode) + // pkgImports maps each internal package path to a set of import paths. + pkgImports := make(map[string]map[string]bool) + + fset := token.NewFileSet() + err = filepath.Walk(projectDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if info.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || base == "testdata" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") { + return nil + } + // Skip test files for dependency analysis. + if strings.HasSuffix(path, "_test.go") { + return nil + } + + f, parseErr := parser.ParseFile(fset, path, nil, parser.ImportsOnly) + if parseErr != nil { + return nil + } + + relDir, _ := filepath.Rel(projectDir, filepath.Dir(path)) + if relDir == "" || relDir == "." { + relDir = "" + } + var pkgPath string + if relDir == "" { + pkgPath = moduleName + } else { + pkgPath = moduleName + "/" + filepath.ToSlash(relDir) + } + + if _, ok := internalPkgs[pkgPath]; !ok { + shortName := filepath.Base(pkgPath) + if pkgPath == moduleName { + shortName = filepath.Base(moduleName) + } + internalPkgs[pkgPath] = &DepNode{ + ID: pkgPath, + Name: shortName, + Type: "internal", + FileCount: 0, + LOC: 0, + ImportedBy: []string{}, + Imports: []string{}, + } + pkgImports[pkgPath] = make(map[string]bool) + } + + internalPkgs[pkgPath].FileCount++ + + // Count LOC. + loc := countFileLOC(path) + internalPkgs[pkgPath].LOC += loc + + // Collect imports. + for _, imp := range f.Imports { + impPath := strings.Trim(imp.Path.Value, `"`) + pkgImports[pkgPath][impPath] = true + } + + return nil + }) + if err != nil { + return fmt.Errorf("depgraph: walk project: %w", err) + } + + // Add internal package nodes. + for id, node := range internalPkgs { + dg.Nodes[id] = node + } + + // Process imports and create edges. + for pkgPath, imports := range pkgImports { + for imp := range imports { + impType := classifyImport(imp, moduleName, externalDeps) + + // Ensure stdlib nodes exist. + if impType == "stdlib" { + if _, ok := dg.Nodes[imp]; !ok { + dg.Nodes[imp] = &DepNode{ + ID: imp, + Name: filepath.Base(imp), + Type: "stdlib", + ImportedBy: []string{}, + Imports: []string{}, + } + } + } + + // Record the import relationship. + if node, ok := dg.Nodes[pkgPath]; ok { + node.Imports = appendUniqueStr(node.Imports, imp) + } + if node, ok := dg.Nodes[imp]; ok { + node.ImportedBy = appendUniqueStr(node.ImportedBy, pkgPath) + } + + // Add edge. + found := false + for i, e := range dg.Edges { + if e.From == pkgPath && e.To == imp { + dg.Edges[i].Weight++ + found = true + break + } + } + if !found { + dg.Edges = append(dg.Edges, DepEdge{ + From: pkgPath, + To: imp, + Weight: 1, + }) + } + } + } + + return nil +} + +// BuildFromPackageJSON reads package.json and scans JS/TS files to build the +// dependency graph. +func (dg *DepGraph) BuildFromPackageJSON(projectDir string) error { + dg.mu.Lock() + defer dg.mu.Unlock() + + pkgJSONPath := filepath.Join(projectDir, "package.json") + data, err := os.ReadFile(pkgJSONPath) + if err != nil { + return fmt.Errorf("depgraph: read package.json: %w", err) + } + + var pkgJSON struct { + Name string `json:"name"` + Dependencies map[string]string `json:"dependencies"` + DevDependencies map[string]string `json:"devDependencies"` + } + if unmarshalErr := json.Unmarshal(data, &pkgJSON); unmarshalErr != nil { + return fmt.Errorf("depgraph: parse package.json: %w", unmarshalErr) + } + + dg.Root = pkgJSON.Name + + // Add the root package node. + dg.Nodes[pkgJSON.Name] = &DepNode{ + ID: pkgJSON.Name, + Name: pkgJSON.Name, + Type: "internal", + ImportedBy: []string{}, + Imports: []string{}, + } + + // Collect all declared dependencies. + allDeps := make(map[string]bool) + for dep := range pkgJSON.Dependencies { + allDeps[dep] = true + dg.Nodes[dep] = &DepNode{ + ID: dep, + Name: dep, + Type: "external", + ImportedBy: []string{}, + Imports: []string{}, + } + } + for dep := range pkgJSON.DevDependencies { + allDeps[dep] = true + if _, ok := dg.Nodes[dep]; !ok { + dg.Nodes[dep] = &DepNode{ + ID: dep, + Name: dep, + Type: "external", + ImportedBy: []string{}, + Imports: []string{}, + } + } + } + + // Scan JS/TS files for imports. + jsImportRe := regexp.MustCompile(`(?:import\s+.*?\s+from\s+['"]([^'"]+)['"]|require\s*\(\s*['"]([^'"]+)['"]\s*\))`) + + // Internal modules map (relative imports). + internalModules := make(map[string]*DepNode) + + err = filepath.Walk(projectDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if info.IsDir() { + base := filepath.Base(path) + if base == "node_modules" || base == ".git" || base == "dist" || base == "build" { + return filepath.SkipDir + } + return nil + } + ext := filepath.Ext(path) + if ext != ".js" && ext != ".ts" && ext != ".jsx" && ext != ".tsx" { + return nil + } + + relPath, _ := filepath.Rel(projectDir, path) + relPath = filepath.ToSlash(relPath) + + // Determine the "module" path for this file. + modPath := pkgJSON.Name + "/" + relPath + + if _, ok := internalModules[modPath]; !ok { + internalModules[modPath] = &DepNode{ + ID: modPath, + Name: filepath.Base(relPath), + Type: "internal", + FileCount: 1, + LOC: 0, + ImportedBy: []string{}, + Imports: []string{}, + } + } + internalModules[modPath].LOC += countFileLOC(path) + + // Read file and find imports. + content, readErr := os.ReadFile(path) + if readErr != nil { + return nil + } + + matches := jsImportRe.FindAllStringSubmatch(string(content), -1) + for _, match := range matches { + imported := match[1] + if imported == "" { + imported = match[2] + } + if imported == "" { + continue + } + + var targetID string + if strings.HasPrefix(imported, ".") { + // Relative import => internal. + dir := filepath.Dir(relPath) + resolved := filepath.ToSlash(filepath.Join(dir, imported)) + targetID = pkgJSON.Name + "/" + resolved + if _, ok := internalModules[targetID]; !ok { + internalModules[targetID] = &DepNode{ + ID: targetID, + Name: filepath.Base(imported), + Type: "internal", + ImportedBy: []string{}, + Imports: []string{}, + } + } + } else { + // Package import => external. + // Extract package name (handle scoped packages). + pkgName := imported + if strings.HasPrefix(imported, "@") { + parts := strings.SplitN(imported, "/", 3) + if len(parts) >= 2 { + pkgName = parts[0] + "/" + parts[1] + } + } else { + parts := strings.SplitN(imported, "/", 2) + pkgName = parts[0] + } + targetID = pkgName + if _, ok := dg.Nodes[targetID]; !ok { + nodeType := "external" + if isNodeBuiltin(pkgName) { + nodeType = "stdlib" + } + dg.Nodes[targetID] = &DepNode{ + ID: targetID, + Name: pkgName, + Type: nodeType, + ImportedBy: []string{}, + Imports: []string{}, + } + } + } + + internalModules[modPath].Imports = appendUniqueStr(internalModules[modPath].Imports, targetID) + if node, ok := dg.Nodes[targetID]; ok { + node.ImportedBy = appendUniqueStr(node.ImportedBy, modPath) + } else if mod, ok := internalModules[targetID]; ok { + mod.ImportedBy = appendUniqueStr(mod.ImportedBy, modPath) + } + + // Add edge. + found := false + for i, e := range dg.Edges { + if e.From == modPath && e.To == targetID { + dg.Edges[i].Weight++ + found = true + break + } + } + if !found { + dg.Edges = append(dg.Edges, DepEdge{ + From: modPath, + To: targetID, + Weight: 1, + }) + } + } + + return nil + }) + if err != nil { + return fmt.Errorf("depgraph: walk project: %w", err) + } + + // Merge internal modules into the graph. + for id, node := range internalModules { + dg.Nodes[id] = node + } + + return nil +} + +// parseModuleName extracts the module name from go.mod content. +func parseModuleName(content string) string { + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "module ") { + return strings.TrimSpace(strings.TrimPrefix(line, "module")) + } + } + return "" +} + +// parseGoModRequires extracts dependency paths from go.mod require blocks. +func parseGoModRequires(content string) []string { + var deps []string + inRequire := false + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "require (") || line == "require (" { + inRequire = true + continue + } + if inRequire { + if line == ")" { + inRequire = false + continue + } + parts := strings.Fields(line) + if len(parts) >= 2 { + deps = append(deps, parts[0]) + } + } + // Single-line require. + if strings.HasPrefix(line, "require ") && !strings.Contains(line, "(") { + parts := strings.Fields(line) + if len(parts) >= 3 { + deps = append(deps, parts[1]) + } + } + } + return deps +} + +// classifyImport determines the type of an import path. +func classifyImport(importPath, moduleName string, externalDeps []string) string { + // Internal: starts with module name. + if strings.HasPrefix(importPath, moduleName) { + return "internal" + } + + // External: matches a known dependency. + for _, dep := range externalDeps { + if strings.HasPrefix(importPath, dep) { + return "external" + } + } + + // If it contains a dot in the first path component, it's likely external. + firstComponent := strings.SplitN(importPath, "/", 2)[0] + if strings.Contains(firstComponent, ".") { + return "external" + } + + return "stdlib" +} + +// countFileLOC counts lines of code in a file (non-blank lines). +func countFileLOC(path string) int { + f, err := os.Open(path) + if err != nil { + return 0 + } + defer func() { _ = f.Close() }() + + count := 0 + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line != "" { + count++ + } + } + return count +} + +// isNodeBuiltin checks if a package name is a Node.js built-in module. +func isNodeBuiltin(name string) bool { + builtins := map[string]bool{ + "fs": true, "path": true, "os": true, "http": true, "https": true, + "crypto": true, "util": true, "events": true, "stream": true, + "child_process": true, "net": true, "url": true, "querystring": true, + "buffer": true, "assert": true, "cluster": true, "dns": true, + "readline": true, "tls": true, "zlib": true, "vm": true, + "process": true, "module": true, "console": true, "timers": true, + } + // Also handle "node:" prefix. + if strings.HasPrefix(name, "node:") { + return true + } + return builtins[name] +} From 08354ae5d740d2063a85fe0717998a2023832a44 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:02:32 +0530 Subject: [PATCH 06/20] refactor(project): split project_analyzer.go into metrics, patterns, and report files --- internal/engine/project/project_analyzer.go | 731 +------------------- internal/engine/project/project_metrics.go | 266 +++++++ internal/engine/project/project_patterns.go | 373 ++++++++++ internal/engine/project/project_report.go | 124 ++++ 4 files changed, 771 insertions(+), 723 deletions(-) create mode 100644 internal/engine/project/project_metrics.go create mode 100644 internal/engine/project/project_patterns.go create mode 100644 internal/engine/project/project_report.go diff --git a/internal/engine/project/project_analyzer.go b/internal/engine/project/project_analyzer.go index 3cd83462..109b5814 100644 --- a/internal/engine/project/project_analyzer.go +++ b/internal/engine/project/project_analyzer.go @@ -14,6 +14,12 @@ import ( "sync" ) +// ProjectAnalyzer performs deep analysis of a codebase. This file holds the +// analyzer type, the top-level Analyze orchestration, per-module analysis, and +// the detection methods. Quantitative metrics live in project_metrics.go, +// architecture/pattern detection in project_patterns.go, and the report +// formatters in project_report.go. + // ProjectAnalysis holds the full analysis of a project's architecture, patterns, and conventions. type ProjectAnalysis struct { Name string @@ -109,154 +115,6 @@ func (pa *ProjectAnalyzer) Analyze() (*ProjectAnalysis, error) { return analysis, nil } -// DetectArchitecture determines the architectural style of a project by examining -// its directory structure. -func DetectArchitecture(dir string) string { - entries, err := os.ReadDir(dir) - if err != nil { - return "unknown" - } - - dirs := make(map[string]bool) - for _, entry := range entries { - if entry.IsDir() && !strings.HasPrefix(entry.Name(), ".") { - dirs[entry.Name()] = true - } - } - - // Hexagonal: domain/ + ports/ + adapters/ - if dirs["domain"] && dirs["ports"] && dirs["adapters"] { - return "hexagonal" - } - - // Microservices: multiple service directories. - serviceCount := 0 - for name := range dirs { - if strings.HasSuffix(name, "-service") || strings.HasSuffix(name, "-svc") { - serviceCount++ - } - } - if dirs["services"] || serviceCount >= 2 { - return "microservices" - } - - // Layered: cmd/ -> service/ or internal/ -> repository/ or repo/ - if dirs["cmd"] && (dirs["service"] || dirs["internal"] || dirs["engine"]) { - if dirs["repo"] || dirs["repository"] || dirs["store"] || dirs["tool"] { - return "layered" - } - return "layered" - } - - // Modular: feature-based directories (more than 4 sibling directories with similar structure). - featureDirs := 0 - for name := range dirs { - subPath := filepath.Join(dir, name) - if hasGoFiles(subPath) { - featureDirs++ - } - } - if featureDirs >= 5 && !dirs["cmd"] { - return "modular" - } - - // Monolith: single main package. - if hasMainPackage(dir) && featureDirs <= 2 { - return "monolith" - } - - // Default to modular if there are many subdirectories. - if featureDirs >= 4 { - return "modular" - } - - return "monolith" -} - -// DetectPatterns identifies design patterns used in the codebase. -func DetectPatterns(dir string) []Pattern { - var patterns []Pattern - - // Repository pattern: *Repository interfaces + implementations. - repoFiles := findFilesWithPattern(dir, "repository", "repo") - if len(repoFiles) > 0 { - patterns = append(patterns, Pattern{ - Name: "Repository", - Description: "Data access abstracted behind repository interfaces", - Files: repoFiles, - Confidence: calculateConfidence(repoFiles, 2), - }) - } - - // Middleware pattern: handler wrappers, interceptors. - middlewareFiles := findFilesWithPattern(dir, "middleware", "interceptor") - if len(middlewareFiles) > 0 { - patterns = append(patterns, Pattern{ - Name: "Middleware", - Description: "Request/response processing chain with handler wrappers", - Files: middlewareFiles, - Confidence: calculateConfidence(middlewareFiles, 2), - }) - } - - // Factory pattern: New* constructors. - factoryFiles := findFactoryPattern(dir) - if len(factoryFiles) > 0 { - patterns = append(patterns, Pattern{ - Name: "Factory", - Description: "Object creation via New* constructor functions", - Files: factoryFiles, - Confidence: calculateConfidence(factoryFiles, 5), - }) - } - - // Observer pattern: event/listener files. - observerFiles := findFilesWithPattern(dir, "event", "listener", "observer", "hook") - if len(observerFiles) > 0 { - patterns = append(patterns, Pattern{ - Name: "Observer", - Description: "Event-driven communication with listeners/hooks", - Files: observerFiles, - Confidence: calculateConfidence(observerFiles, 2), - }) - } - - // Strategy pattern: interface + multiple implementations. - strategyFiles := findStrategyPattern(dir) - if len(strategyFiles) > 0 { - patterns = append(patterns, Pattern{ - Name: "Strategy", - Description: "Interface with multiple interchangeable implementations", - Files: strategyFiles, - Confidence: calculateConfidence(strategyFiles, 3), - }) - } - - // Interface-driven tools pattern. - toolFiles := findFilesWithPattern(dir, "tool") - if len(toolFiles) >= 3 { - patterns = append(patterns, Pattern{ - Name: "Interface-driven tools", - Description: "Tool interface with multiple implementations", - Files: toolFiles, - Confidence: calculateConfidence(toolFiles, 5), - }) - } - - // Functional options pattern (WithXxx). - optionFiles := findFunctionalOptionsPattern(dir) - if len(optionFiles) > 0 { - patterns = append(patterns, Pattern{ - Name: "Functional Options", - Description: "Configuration via WithXxx option functions", - Files: optionFiles, - Confidence: calculateConfidence(optionFiles, 3), - }) - } - - return patterns -} - // AnalyzeModule scans a package directory and extracts its public API, line count, and purpose. func (pa *ProjectAnalyzer) AnalyzeModule(path string) *ModuleInfo { info := &ModuleInfo{ @@ -332,108 +190,7 @@ func (pa *ProjectAnalyzer) AnalyzeModule(path string) *ModuleInfo { return info } -// GenerateOnboardingDoc produces a human-readable onboarding document from the analysis. -func GenerateOnboardingDoc(analysis *ProjectAnalysis) string { - var b strings.Builder - - b.WriteString(fmt.Sprintf("# Project: %s\n\n", analysis.Name)) - - // Architecture section. - b.WriteString(fmt.Sprintf("## Architecture: %s\n", projAnalyzerTitle(analysis.Architecture))) - if len(analysis.KeyModules) > 0 { - moduleNames := make([]string, 0, len(analysis.KeyModules)) - for _, m := range analysis.KeyModules { - moduleNames = append(moduleNames, m.Name) - } - b.WriteString(strings.Join(moduleNames, " -> ")) - b.WriteString("\n") - } - b.WriteString("\n") - - // Key modules section. - b.WriteString("## Key Modules\n") - for _, m := range analysis.KeyModules { - locStr := formatLOC(m.Size) - purpose := m.Purpose - if purpose == "" { - purpose = "Core functionality" - } - b.WriteString(fmt.Sprintf("- %s (%s): %s\n", m.Name, locStr, purpose)) - } - b.WriteString("\n") - - // Patterns section. - if len(analysis.Patterns) > 0 { - b.WriteString("## Patterns Detected\n") - for _, p := range analysis.Patterns { - if p.Confidence >= 0.5 { - b.WriteString(fmt.Sprintf("- %s (%s)\n", p.Name, p.Description)) - } - } - b.WriteString("\n") - } - - // Conventions section. - if len(analysis.Conventions) > 0 { - b.WriteString("## Conventions\n") - for _, c := range analysis.Conventions { - b.WriteString(fmt.Sprintf("- %s\n", c)) - } - b.WriteString("\n") - } - - // Stats section. - b.WriteString("## Stats\n") - b.WriteString(fmt.Sprintf("- Language: %s\n", analysis.Language)) - if analysis.Framework != "" { - b.WriteString(fmt.Sprintf("- Framework: %s\n", analysis.Framework)) - } - b.WriteString(fmt.Sprintf("- Total LOC: %s\n", formatLOC(analysis.LOC))) - b.WriteString(fmt.Sprintf("- Dependencies: %d\n", analysis.Dependencies)) - b.WriteString(fmt.Sprintf("- Test Coverage: %s\n", analysis.TestCoverage)) - b.WriteString(fmt.Sprintf("- Complexity: %s\n", analysis.Complexity)) - - return b.String() -} - -// FormatAnalysis produces a concise summary string from a ProjectAnalysis. -func FormatAnalysis(analysis *ProjectAnalysis) string { - var b strings.Builder - - b.WriteString(fmt.Sprintf("Project: %s (%s", analysis.Name, analysis.Language)) - if analysis.Framework != "" { - b.WriteString(fmt.Sprintf(" / %s", analysis.Framework)) - } - b.WriteString(")\n") - - b.WriteString(fmt.Sprintf("Architecture: %s\n", analysis.Architecture)) - b.WriteString(fmt.Sprintf("LOC: %s | Deps: %d | Tests: %s | Complexity: %s\n", - formatLOC(analysis.LOC), analysis.Dependencies, analysis.TestCoverage, analysis.Complexity)) - - if len(analysis.EntryPoints) > 0 { - b.WriteString(fmt.Sprintf("Entry Points: %s\n", strings.Join(analysis.EntryPoints, ", "))) - } - - if len(analysis.KeyModules) > 0 { - b.WriteString(fmt.Sprintf("Modules: %d key modules\n", len(analysis.KeyModules))) - } - - if len(analysis.Patterns) > 0 { - patternNames := make([]string, 0, len(analysis.Patterns)) - for _, p := range analysis.Patterns { - if p.Confidence >= 0.5 { - patternNames = append(patternNames, p.Name) - } - } - if len(patternNames) > 0 { - b.WriteString(fmt.Sprintf("Patterns: %s\n", strings.Join(patternNames, ", "))) - } - } - - return b.String() -} - -// --- Private helper methods --- +// --- Detection methods --- func (pa *ProjectAnalyzer) detectProjectName() string { // Try go.mod first. @@ -664,142 +421,6 @@ func (pa *ProjectAnalyzer) detectConventions() []string { return conventions } -func (pa *ProjectAnalyzer) countDependencies() int { - modPath := filepath.Join(pa.Dir, "go.mod") - data, err := os.ReadFile(modPath) - if err != nil { - return 0 - } - - count := 0 - inRequire := false - for _, line := range strings.Split(string(data), "\n") { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "require (") { - inRequire = true - continue - } - if inRequire && line == ")" { - inRequire = false - continue - } - if inRequire && line != "" && !strings.HasPrefix(line, "//") { - count++ - } - // Single-line require. - if strings.HasPrefix(line, "require ") && !strings.Contains(line, "(") { - count++ - } - } - - return count -} - -func (pa *ProjectAnalyzer) countLOC() int { - total := 0 - _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == "node_modules" || base == ".git" || strings.HasPrefix(base, ".") { - return filepath.SkipDir - } - return nil - } - if strings.HasSuffix(path, ".go") { - total += countFileLines(path) - } - return nil - }) - return total -} - -func (pa *ProjectAnalyzer) assessTestCoverage() string { - totalPkgs := 0 - testedPkgs := 0 - - _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if !d.IsDir() { - return nil - } - base := filepath.Base(path) - if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") || base == "testdata" { - return filepath.SkipDir - } - if hasGoFiles(path) { - totalPkgs++ - if hasTestFiles(path) { - testedPkgs++ - } - } - return nil - }) - - if totalPkgs == 0 { - return "unknown" - } - - pct := float64(testedPkgs) / float64(totalPkgs) * 100 - return fmt.Sprintf("%.0f%% (%d/%d packages have tests)", pct, testedPkgs, totalPkgs) -} - -func (pa *ProjectAnalyzer) assessComplexity() string { - totalFuncs := 0 - longFuncs := 0 // functions > 50 lines - - _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - fset := token.NewFileSet() - f, parseErr := parser.ParseFile(fset, path, nil, 0) - if parseErr != nil { - return nil - } - - for _, decl := range f.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - totalFuncs++ - startLine := fset.Position(fn.Pos()).Line - endLine := fset.Position(fn.End()).Line - if endLine-startLine > 50 { - longFuncs++ - } - } - } - return nil - }) - - if totalFuncs == 0 { - return "unknown" - } - - longPct := float64(longFuncs) / float64(totalFuncs) * 100 - if longPct > 20 { - return fmt.Sprintf("high (%d/%d functions >50 lines)", longFuncs, totalFuncs) - } - if longPct > 10 { - return fmt.Sprintf("moderate (%d/%d functions >50 lines)", longFuncs, totalFuncs) - } - return fmt.Sprintf("low (%d/%d functions >50 lines)", longFuncs, totalFuncs) -} - func (pa *ProjectAnalyzer) inferPurpose(info *ModuleInfo) string { name := strings.ToLower(info.Name) @@ -856,329 +477,7 @@ func (pa *ProjectAnalyzer) inferPurpose(info *ModuleInfo) string { return "Module functionality" } -func (pa *ProjectAnalyzer) hasPatternInFiles(pattern string) bool { - found := false - _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { - if found || err != nil { - return filepath.SkipAll - } - if d.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - data, readErr := os.ReadFile(path) - if readErr != nil { - return nil - } - if strings.Contains(string(data), pattern) { - found = true - return filepath.SkipAll - } - return nil - }) - return found -} - -func (pa *ProjectAnalyzer) hasPatternInTestFiles(pattern string) bool { - found := false - _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { - if found || err != nil { - return filepath.SkipAll - } - if d.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, "_test.go") { - return nil - } - data, readErr := os.ReadFile(path) - if readErr != nil { - return nil - } - if strings.Contains(string(data), pattern) { - found = true - return filepath.SkipAll - } - return nil - }) - return found -} - -func (pa *ProjectAnalyzer) countInterfaces() int { - count := 0 - _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - fset := token.NewFileSet() - f, parseErr := parser.ParseFile(fset, path, nil, 0) - if parseErr != nil { - return nil - } - - for _, decl := range f.Decls { - if gd, ok := decl.(*ast.GenDecl); ok && gd.Tok == token.TYPE { - for _, spec := range gd.Specs { - if ts, ok := spec.(*ast.TypeSpec); ok { - if _, isIface := ts.Type.(*ast.InterfaceType); isIface && ts.Name.IsExported() { - count++ - } - } - } - } - } - return nil - }) - return count -} - -// --- Package-level helper functions --- - -func hasGoFiles(dir string) bool { - entries, err := os.ReadDir(dir) - if err != nil { - return false - } - for _, entry := range entries { - if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") { - return true - } - } - return false -} - -func hasMainPackage(dir string) bool { - mainFile := filepath.Join(dir, "main.go") - _, err := os.Stat(mainFile) - return err == nil -} - -func countFileLines(path string) int { - f, err := os.Open(path) - if err != nil { - return 0 - } - defer func() { _ = f.Close() }() - - count := 0 - scanner := bufio.NewScanner(f) - for scanner.Scan() { - count++ - } - return count -} - -func findFilesWithPattern(dir string, patterns ...string) []string { - var files []string - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - lower := strings.ToLower(filepath.Base(path)) - for _, pattern := range patterns { - if strings.Contains(lower, pattern) { - relPath, relErr := filepath.Rel(dir, path) - if relErr == nil { - files = append(files, relPath) - } - break - } - } - return nil - }) - return files -} - -func findFactoryPattern(dir string) []string { - var files []string - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - fset := token.NewFileSet() - f, parseErr := parser.ParseFile(fset, path, nil, 0) - if parseErr != nil { - return nil - } - - newCount := 0 - for _, decl := range f.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - if strings.HasPrefix(fn.Name.Name, "New") && fn.Name.IsExported() { - newCount++ - } - } - } - - if newCount >= 2 { - relPath, relErr := filepath.Rel(dir, path) - if relErr == nil { - files = append(files, relPath) - } - } - return nil - }) - - // Limit results. - if len(files) > 10 { - files = files[:10] - } - return files -} - -func findStrategyPattern(dir string) []string { - // Look for files that define an interface and have sibling files implementing it. - var files []string - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - fset := token.NewFileSet() - f, parseErr := parser.ParseFile(fset, path, nil, 0) - if parseErr != nil { - return nil - } - - // Check if the file defines interfaces with multiple methods. - for _, decl := range f.Decls { - if gd, ok := decl.(*ast.GenDecl); ok && gd.Tok == token.TYPE { - for _, spec := range gd.Specs { - if ts, ok := spec.(*ast.TypeSpec); ok { - if iface, isIface := ts.Type.(*ast.InterfaceType); isIface { - if iface.Methods != nil && len(iface.Methods.List) >= 2 { - relPath, relErr := filepath.Rel(dir, path) - if relErr == nil { - files = append(files, relPath) - } - } - } - } - } - } - } - return nil - }) - - if len(files) > 10 { - files = files[:10] - } - return files -} - -func findFunctionalOptionsPattern(dir string) []string { - var files []string - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - base := filepath.Base(path) - if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - fset := token.NewFileSet() - f, parseErr := parser.ParseFile(fset, path, nil, 0) - if parseErr != nil { - return nil - } - - withCount := 0 - for _, decl := range f.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - if strings.HasPrefix(fn.Name.Name, "With") && fn.Name.IsExported() { - withCount++ - } - } - } - - if withCount >= 3 { - relPath, relErr := filepath.Rel(dir, path) - if relErr == nil { - files = append(files, relPath) - } - } - return nil - }) - - if len(files) > 10 { - files = files[:10] - } - return files -} - -func calculateConfidence(files []string, threshold int) float64 { - count := len(files) - if count >= threshold*2 { - return 0.95 - } - if count >= threshold { - return 0.8 - } - if count >= 1 { - return 0.5 + float64(count)/float64(threshold)*0.3 - } - return 0.0 -} +// --- Small AST/string helpers --- func projAnalyzerExprToString(expr ast.Expr) string { switch t := expr.(type) { @@ -1193,13 +492,6 @@ func projAnalyzerExprToString(expr ast.Expr) string { } } -func formatLOC(loc int) string { - if loc >= 1000 { - return fmt.Sprintf("%dK LOC", loc/1000) - } - return fmt.Sprintf("%d LOC", loc) -} - func projAnalyzerAppendUnique(slice []string, s string) []string { for _, existing := range slice { if existing == s { @@ -1215,10 +507,3 @@ func projAnalyzerMin(a, b int) int { } return b } - -func projAnalyzerTitle(s string) string { - if s == "" { - return s - } - return strings.ToUpper(s[:1]) + s[1:] -} diff --git a/internal/engine/project/project_metrics.go b/internal/engine/project/project_metrics.go new file mode 100644 index 00000000..c2111419 --- /dev/null +++ b/internal/engine/project/project_metrics.go @@ -0,0 +1,266 @@ +package project + +import ( + "bufio" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io/fs" + "os" + "path/filepath" + "strings" +) + +// This file holds the quantitative project metrics gathered by ProjectAnalyzer +// (dependency count, LOC, test coverage, complexity, interface count, and +// content scans). Architecture/pattern detection lives in project_patterns.go; +// the core analyzer and report formatting live in their own files. + +func (pa *ProjectAnalyzer) countDependencies() int { + modPath := filepath.Join(pa.Dir, "go.mod") + data, err := os.ReadFile(modPath) + if err != nil { + return 0 + } + + count := 0 + inRequire := false + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "require (") { + inRequire = true + continue + } + if inRequire && line == ")" { + inRequire = false + continue + } + if inRequire && line != "" && !strings.HasPrefix(line, "//") { + count++ + } + // Single-line require. + if strings.HasPrefix(line, "require ") && !strings.Contains(line, "(") { + count++ + } + } + + return count +} + +func (pa *ProjectAnalyzer) countLOC() int { + total := 0 + _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == "node_modules" || base == ".git" || strings.HasPrefix(base, ".") { + return filepath.SkipDir + } + return nil + } + if strings.HasSuffix(path, ".go") { + total += countFileLines(path) + } + return nil + }) + return total +} + +func (pa *ProjectAnalyzer) assessTestCoverage() string { + totalPkgs := 0 + testedPkgs := 0 + + _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() { + return nil + } + base := filepath.Base(path) + if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") || base == "testdata" { + return filepath.SkipDir + } + if hasGoFiles(path) { + totalPkgs++ + if hasTestFiles(path) { + testedPkgs++ + } + } + return nil + }) + + if totalPkgs == 0 { + return "unknown" + } + + pct := float64(testedPkgs) / float64(totalPkgs) * 100 + return fmt.Sprintf("%.0f%% (%d/%d packages have tests)", pct, testedPkgs, totalPkgs) +} + +func (pa *ProjectAnalyzer) assessComplexity() string { + totalFuncs := 0 + longFuncs := 0 // functions > 50 lines + + _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + fset := token.NewFileSet() + f, parseErr := parser.ParseFile(fset, path, nil, 0) + if parseErr != nil { + return nil + } + + for _, decl := range f.Decls { + if fn, ok := decl.(*ast.FuncDecl); ok { + totalFuncs++ + startLine := fset.Position(fn.Pos()).Line + endLine := fset.Position(fn.End()).Line + if endLine-startLine > 50 { + longFuncs++ + } + } + } + return nil + }) + + if totalFuncs == 0 { + return "unknown" + } + + longPct := float64(longFuncs) / float64(totalFuncs) * 100 + if longPct > 20 { + return fmt.Sprintf("high (%d/%d functions >50 lines)", longFuncs, totalFuncs) + } + if longPct > 10 { + return fmt.Sprintf("moderate (%d/%d functions >50 lines)", longFuncs, totalFuncs) + } + return fmt.Sprintf("low (%d/%d functions >50 lines)", longFuncs, totalFuncs) +} + +func (pa *ProjectAnalyzer) hasPatternInFiles(pattern string) bool { + found := false + _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { + if found || err != nil { + return filepath.SkipAll + } + if d.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + data, readErr := os.ReadFile(path) + if readErr != nil { + return nil + } + if strings.Contains(string(data), pattern) { + found = true + return filepath.SkipAll + } + return nil + }) + return found +} + +func (pa *ProjectAnalyzer) hasPatternInTestFiles(pattern string) bool { + found := false + _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { + if found || err != nil { + return filepath.SkipAll + } + if d.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, "_test.go") { + return nil + } + data, readErr := os.ReadFile(path) + if readErr != nil { + return nil + } + if strings.Contains(string(data), pattern) { + found = true + return filepath.SkipAll + } + return nil + }) + return found +} + +func (pa *ProjectAnalyzer) countInterfaces() int { + count := 0 + _ = filepath.WalkDir(pa.Dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + fset := token.NewFileSet() + f, parseErr := parser.ParseFile(fset, path, nil, 0) + if parseErr != nil { + return nil + } + + for _, decl := range f.Decls { + if gd, ok := decl.(*ast.GenDecl); ok && gd.Tok == token.TYPE { + for _, spec := range gd.Specs { + if ts, ok := spec.(*ast.TypeSpec); ok { + if _, isIface := ts.Type.(*ast.InterfaceType); isIface && ts.Name.IsExported() { + count++ + } + } + } + } + } + return nil + }) + return count +} + +func countFileLines(path string) int { + f, err := os.Open(path) + if err != nil { + return 0 + } + defer func() { _ = f.Close() }() + + count := 0 + scanner := bufio.NewScanner(f) + for scanner.Scan() { + count++ + } + return count +} diff --git a/internal/engine/project/project_patterns.go b/internal/engine/project/project_patterns.go new file mode 100644 index 00000000..893a109d --- /dev/null +++ b/internal/engine/project/project_patterns.go @@ -0,0 +1,373 @@ +package project + +import ( + "go/ast" + "go/parser" + "go/token" + "io/fs" + "os" + "path/filepath" + "strings" +) + +// This file holds architecture and design-pattern detection (DetectArchitecture, +// DetectPatterns and their scanners) plus the small structural helpers they +// share. Quantitative metrics live in project_metrics.go. + +// DetectArchitecture determines the architectural style of a project by examining +// its directory structure. +func DetectArchitecture(dir string) string { + entries, err := os.ReadDir(dir) + if err != nil { + return "unknown" + } + + dirs := make(map[string]bool) + for _, entry := range entries { + if entry.IsDir() && !strings.HasPrefix(entry.Name(), ".") { + dirs[entry.Name()] = true + } + } + + // Hexagonal: domain/ + ports/ + adapters/ + if dirs["domain"] && dirs["ports"] && dirs["adapters"] { + return "hexagonal" + } + + // Microservices: multiple service directories. + serviceCount := 0 + for name := range dirs { + if strings.HasSuffix(name, "-service") || strings.HasSuffix(name, "-svc") { + serviceCount++ + } + } + if dirs["services"] || serviceCount >= 2 { + return "microservices" + } + + // Layered: cmd/ -> service/ or internal/ -> repository/ or repo/ + if dirs["cmd"] && (dirs["service"] || dirs["internal"] || dirs["engine"]) { + if dirs["repo"] || dirs["repository"] || dirs["store"] || dirs["tool"] { + return "layered" + } + return "layered" + } + + // Modular: feature-based directories (more than 4 sibling directories with similar structure). + featureDirs := 0 + for name := range dirs { + subPath := filepath.Join(dir, name) + if hasGoFiles(subPath) { + featureDirs++ + } + } + if featureDirs >= 5 && !dirs["cmd"] { + return "modular" + } + + // Monolith: single main package. + if hasMainPackage(dir) && featureDirs <= 2 { + return "monolith" + } + + // Default to modular if there are many subdirectories. + if featureDirs >= 4 { + return "modular" + } + + return "monolith" +} + +// DetectPatterns identifies design patterns used in the codebase. +func DetectPatterns(dir string) []Pattern { + var patterns []Pattern + + // Repository pattern: *Repository interfaces + implementations. + repoFiles := findFilesWithPattern(dir, "repository", "repo") + if len(repoFiles) > 0 { + patterns = append(patterns, Pattern{ + Name: "Repository", + Description: "Data access abstracted behind repository interfaces", + Files: repoFiles, + Confidence: calculateConfidence(repoFiles, 2), + }) + } + + // Middleware pattern: handler wrappers, interceptors. + middlewareFiles := findFilesWithPattern(dir, "middleware", "interceptor") + if len(middlewareFiles) > 0 { + patterns = append(patterns, Pattern{ + Name: "Middleware", + Description: "Request/response processing chain with handler wrappers", + Files: middlewareFiles, + Confidence: calculateConfidence(middlewareFiles, 2), + }) + } + + // Factory pattern: New* constructors. + factoryFiles := findFactoryPattern(dir) + if len(factoryFiles) > 0 { + patterns = append(patterns, Pattern{ + Name: "Factory", + Description: "Object creation via New* constructor functions", + Files: factoryFiles, + Confidence: calculateConfidence(factoryFiles, 5), + }) + } + + // Observer pattern: event/listener files. + observerFiles := findFilesWithPattern(dir, "event", "listener", "observer", "hook") + if len(observerFiles) > 0 { + patterns = append(patterns, Pattern{ + Name: "Observer", + Description: "Event-driven communication with listeners/hooks", + Files: observerFiles, + Confidence: calculateConfidence(observerFiles, 2), + }) + } + + // Strategy pattern: interface + multiple implementations. + strategyFiles := findStrategyPattern(dir) + if len(strategyFiles) > 0 { + patterns = append(patterns, Pattern{ + Name: "Strategy", + Description: "Interface with multiple interchangeable implementations", + Files: strategyFiles, + Confidence: calculateConfidence(strategyFiles, 3), + }) + } + + // Interface-driven tools pattern. + toolFiles := findFilesWithPattern(dir, "tool") + if len(toolFiles) >= 3 { + patterns = append(patterns, Pattern{ + Name: "Interface-driven tools", + Description: "Tool interface with multiple implementations", + Files: toolFiles, + Confidence: calculateConfidence(toolFiles, 5), + }) + } + + // Functional options pattern (WithXxx). + optionFiles := findFunctionalOptionsPattern(dir) + if len(optionFiles) > 0 { + patterns = append(patterns, Pattern{ + Name: "Functional Options", + Description: "Configuration via WithXxx option functions", + Files: optionFiles, + Confidence: calculateConfidence(optionFiles, 3), + }) + } + + return patterns +} + +func hasGoFiles(dir string) bool { + entries, err := os.ReadDir(dir) + if err != nil { + return false + } + for _, entry := range entries { + if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") { + return true + } + } + return false +} + +func hasMainPackage(dir string) bool { + mainFile := filepath.Join(dir, "main.go") + _, err := os.Stat(mainFile) + return err == nil +} + +func findFilesWithPattern(dir string, patterns ...string) []string { + var files []string + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + lower := strings.ToLower(filepath.Base(path)) + for _, pattern := range patterns { + if strings.Contains(lower, pattern) { + relPath, relErr := filepath.Rel(dir, path) + if relErr == nil { + files = append(files, relPath) + } + break + } + } + return nil + }) + return files +} + +func findFactoryPattern(dir string) []string { + var files []string + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + fset := token.NewFileSet() + f, parseErr := parser.ParseFile(fset, path, nil, 0) + if parseErr != nil { + return nil + } + + newCount := 0 + for _, decl := range f.Decls { + if fn, ok := decl.(*ast.FuncDecl); ok { + if strings.HasPrefix(fn.Name.Name, "New") && fn.Name.IsExported() { + newCount++ + } + } + } + + if newCount >= 2 { + relPath, relErr := filepath.Rel(dir, path) + if relErr == nil { + files = append(files, relPath) + } + } + return nil + }) + + // Limit results. + if len(files) > 10 { + files = files[:10] + } + return files +} + +func findStrategyPattern(dir string) []string { + // Look for files that define an interface and have sibling files implementing it. + var files []string + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + fset := token.NewFileSet() + f, parseErr := parser.ParseFile(fset, path, nil, 0) + if parseErr != nil { + return nil + } + + // Check if the file defines interfaces with multiple methods. + for _, decl := range f.Decls { + if gd, ok := decl.(*ast.GenDecl); ok && gd.Tok == token.TYPE { + for _, spec := range gd.Specs { + if ts, ok := spec.(*ast.TypeSpec); ok { + if iface, isIface := ts.Type.(*ast.InterfaceType); isIface { + if iface.Methods != nil && len(iface.Methods.List) >= 2 { + relPath, relErr := filepath.Rel(dir, path) + if relErr == nil { + files = append(files, relPath) + } + } + } + } + } + } + } + return nil + }) + + if len(files) > 10 { + files = files[:10] + } + return files +} + +func findFunctionalOptionsPattern(dir string) []string { + var files []string + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || strings.HasPrefix(base, ".") { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + fset := token.NewFileSet() + f, parseErr := parser.ParseFile(fset, path, nil, 0) + if parseErr != nil { + return nil + } + + withCount := 0 + for _, decl := range f.Decls { + if fn, ok := decl.(*ast.FuncDecl); ok { + if strings.HasPrefix(fn.Name.Name, "With") && fn.Name.IsExported() { + withCount++ + } + } + } + + if withCount >= 3 { + relPath, relErr := filepath.Rel(dir, path) + if relErr == nil { + files = append(files, relPath) + } + } + return nil + }) + + if len(files) > 10 { + files = files[:10] + } + return files +} + +func calculateConfidence(files []string, threshold int) float64 { + count := len(files) + if count >= threshold*2 { + return 0.95 + } + if count >= threshold { + return 0.8 + } + if count >= 1 { + return 0.5 + float64(count)/float64(threshold)*0.3 + } + return 0.0 +} diff --git a/internal/engine/project/project_report.go b/internal/engine/project/project_report.go new file mode 100644 index 00000000..aceb3537 --- /dev/null +++ b/internal/engine/project/project_report.go @@ -0,0 +1,124 @@ +package project + +import ( + "fmt" + "strings" +) + +// This file holds the human-readable output formatters for a ProjectAnalysis +// (onboarding doc and concise summary) plus their small formatting helpers. + +// GenerateOnboardingDoc produces a human-readable onboarding document from the analysis. +func GenerateOnboardingDoc(analysis *ProjectAnalysis) string { + var b strings.Builder + + b.WriteString(fmt.Sprintf("# Project: %s\n\n", analysis.Name)) + + // Architecture section. + b.WriteString(fmt.Sprintf("## Architecture: %s\n", projAnalyzerTitle(analysis.Architecture))) + if len(analysis.KeyModules) > 0 { + moduleNames := make([]string, 0, len(analysis.KeyModules)) + for _, m := range analysis.KeyModules { + moduleNames = append(moduleNames, m.Name) + } + b.WriteString(strings.Join(moduleNames, " -> ")) + b.WriteString("\n") + } + b.WriteString("\n") + + // Key modules section. + b.WriteString("## Key Modules\n") + for _, m := range analysis.KeyModules { + locStr := formatLOC(m.Size) + purpose := m.Purpose + if purpose == "" { + purpose = "Core functionality" + } + b.WriteString(fmt.Sprintf("- %s (%s): %s\n", m.Name, locStr, purpose)) + } + b.WriteString("\n") + + // Patterns section. + if len(analysis.Patterns) > 0 { + b.WriteString("## Patterns Detected\n") + for _, p := range analysis.Patterns { + if p.Confidence >= 0.5 { + b.WriteString(fmt.Sprintf("- %s (%s)\n", p.Name, p.Description)) + } + } + b.WriteString("\n") + } + + // Conventions section. + if len(analysis.Conventions) > 0 { + b.WriteString("## Conventions\n") + for _, c := range analysis.Conventions { + b.WriteString(fmt.Sprintf("- %s\n", c)) + } + b.WriteString("\n") + } + + // Stats section. + b.WriteString("## Stats\n") + b.WriteString(fmt.Sprintf("- Language: %s\n", analysis.Language)) + if analysis.Framework != "" { + b.WriteString(fmt.Sprintf("- Framework: %s\n", analysis.Framework)) + } + b.WriteString(fmt.Sprintf("- Total LOC: %s\n", formatLOC(analysis.LOC))) + b.WriteString(fmt.Sprintf("- Dependencies: %d\n", analysis.Dependencies)) + b.WriteString(fmt.Sprintf("- Test Coverage: %s\n", analysis.TestCoverage)) + b.WriteString(fmt.Sprintf("- Complexity: %s\n", analysis.Complexity)) + + return b.String() +} + +// FormatAnalysis produces a concise summary string from a ProjectAnalysis. +func FormatAnalysis(analysis *ProjectAnalysis) string { + var b strings.Builder + + b.WriteString(fmt.Sprintf("Project: %s (%s", analysis.Name, analysis.Language)) + if analysis.Framework != "" { + b.WriteString(fmt.Sprintf(" / %s", analysis.Framework)) + } + b.WriteString(")\n") + + b.WriteString(fmt.Sprintf("Architecture: %s\n", analysis.Architecture)) + b.WriteString(fmt.Sprintf("LOC: %s | Deps: %d | Tests: %s | Complexity: %s\n", + formatLOC(analysis.LOC), analysis.Dependencies, analysis.TestCoverage, analysis.Complexity)) + + if len(analysis.EntryPoints) > 0 { + b.WriteString(fmt.Sprintf("Entry Points: %s\n", strings.Join(analysis.EntryPoints, ", "))) + } + + if len(analysis.KeyModules) > 0 { + b.WriteString(fmt.Sprintf("Modules: %d key modules\n", len(analysis.KeyModules))) + } + + if len(analysis.Patterns) > 0 { + patternNames := make([]string, 0, len(analysis.Patterns)) + for _, p := range analysis.Patterns { + if p.Confidence >= 0.5 { + patternNames = append(patternNames, p.Name) + } + } + if len(patternNames) > 0 { + b.WriteString(fmt.Sprintf("Patterns: %s\n", strings.Join(patternNames, ", "))) + } + } + + return b.String() +} + +func formatLOC(loc int) string { + if loc >= 1000 { + return fmt.Sprintf("%dK LOC", loc/1000) + } + return fmt.Sprintf("%d LOC", loc) +} + +func projAnalyzerTitle(s string) string { + if s == "" { + return s + } + return strings.ToUpper(s[:1]) + s[1:] +} From 2fba194ea1f3d20cc6ff95f34aa81a7b0355f9c0 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:06:02 +0530 Subject: [PATCH 07/20] refactor(scaffold): move built-in templates to scaffold_builtins.go --- internal/engine/scaffold/scaffold.go | 783 ----------------- internal/engine/scaffold/scaffold_builtins.go | 788 ++++++++++++++++++ 2 files changed, 788 insertions(+), 783 deletions(-) create mode 100644 internal/engine/scaffold/scaffold_builtins.go diff --git a/internal/engine/scaffold/scaffold.go b/internal/engine/scaffold/scaffold.go index 463448a3..79519bab 100644 --- a/internal/engine/scaffold/scaffold.go +++ b/internal/engine/scaffold/scaffold.go @@ -57,789 +57,6 @@ func NewScaffolder() *Scaffolder { return s } -func (s *Scaffolder) registerBuiltins() { - s.Templates["go-cli"] = &Template{ - Name: "go-cli", - Description: "Go CLI application with Cobra", - Language: "go", - Framework: "cobra", - Variables: []TemplateVariable{ - {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, - {Name: "Module", Description: "Go module path", Required: true, Type: "string"}, - {Name: "Author", Description: "Author name", Default: "Developer", Type: "string"}, - {Name: "License", Description: "License type", Default: "MIT", Type: "choice", Choices: []string{"MIT", "Apache-2.0", "BSD-3-Clause"}}, - }, - Files: []TemplateFile{ - { - Path: "{{.ProjectName}}/cmd/main.go", - Content: `package main - -import ( - "fmt" - "os" - - "{{.Module}}/internal/cmd" -) - -func main() { - if err := cmd.Execute(); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } -} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/internal/cmd/root.go", - Content: `package cmd - -import ( - "fmt" - "os" -) - -// Execute runs the root command. -func Execute() error { - if len(os.Args) < 2 { - fmt.Println("{{.ProjectName}} - A CLI application") - fmt.Println("Usage: {{.ProjectName}} ") - return nil - } - switch os.Args[1] { - case "version": - fmt.Println("{{.ProjectName}} v0.1.0") - default: - return fmt.Errorf("unknown command: %s", os.Args[1]) - } - return nil -} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/go.mod", - Content: `module {{.Module}} - -go 1.21 -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/Makefile", - Content: `.PHONY: build test clean - -BINARY={{.ProjectName}} - -build: - go build -o bin/$(BINARY) ./cmd/main.go - -test: - go test ./... - -clean: - rm -rf bin/ - -lint: - golangci-lint run ./... -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/.gitignore", - Content: `bin/ -*.exe -*.exe~ -*.dll -*.so -*.dylib -*.test -*.out -vendor/ -.idea/ -.vscode/ -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/README.md", - Content: `# {{.ProjectName}} - -{{.ProjectName}} is a CLI application. - -## Installation - -` + "```bash" + ` -go install {{.Module}}/cmd@latest -` + "```" + ` - -## Usage - -` + "```bash" + ` -{{.ProjectName}} version -` + "```" + ` - -## Author - -{{.Author}} - -## License - -{{.License}} -`, - Mode: 0o644, - }, - }, - PostCreate: []string{"cd {{.ProjectName}} && go mod tidy"}, - } - - s.Templates["go-api"] = &Template{ - Name: "go-api", - Description: "Go REST API with net/http", - Language: "go", - Framework: "net/http", - Variables: []TemplateVariable{ - {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, - {Name: "Module", Description: "Go module path", Required: true, Type: "string"}, - {Name: "Port", Description: "Server port", Default: "8080", Type: "string"}, - {Name: "WithDocker", Description: "Include Dockerfile", Default: "true", Type: "bool"}, - }, - Files: []TemplateFile{ - { - Path: "{{.ProjectName}}/cmd/server/main.go", - Content: `package main - -import ( - "fmt" - "log" - "net/http" - - "{{.Module}}/internal/handler" - "{{.Module}}/internal/middleware" -) - -func main() { - mux := http.NewServeMux() - - mux.HandleFunc("GET /health", handler.Health) - mux.HandleFunc("GET /api/v1/items", handler.ListItems) - mux.HandleFunc("POST /api/v1/items", handler.CreateItem) - - wrapped := middleware.Logger(middleware.Recovery(mux)) - - addr := ":{{.Port}}" - fmt.Printf("Server starting on %s\n", addr) - log.Fatal(http.ListenAndServe(addr, wrapped)) -} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/internal/handler/handler.go", - Content: `package handler - -import ( - "encoding/json" - "net/http" -) - -// Health returns service health status. -func Health(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) -} - -// ListItems returns all items. -func ListItems(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode([]string{}) -} - -// CreateItem creates a new item. -func CreateItem(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]string{"status": "created"}) -} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/internal/middleware/middleware.go", - Content: `package middleware - -import ( - "log" - "net/http" - "time" -) - -// Logger logs incoming requests. -func Logger(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - next.ServeHTTP(w, r) - log.Printf("%s %s %v", r.Method, r.URL.Path, time.Since(start)) - }) -} - -// Recovery recovers from panics. -func Recovery(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - if err := recover(); err != nil { - log.Printf("panic recovered: %v", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - } - }() - next.ServeHTTP(w, r) - }) -} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/go.mod", - Content: `module {{.Module}} - -go 1.21 -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/Dockerfile", - Content: `FROM golang:1.21-alpine AS builder -WORKDIR /app -COPY go.mod go.sum ./ -RUN go mod download -COPY . . -RUN CGO_ENABLED=0 go build -o server ./cmd/server/main.go - -FROM alpine:latest -RUN apk --no-cache add ca-certificates -WORKDIR /root/ -COPY --from=builder /app/server . -EXPOSE {{.Port}} -CMD ["./server"] -`, - Mode: 0o644, - Condition: "{{.WithDocker}}", - }, - { - Path: "{{.ProjectName}}/.gitignore", - Content: `bin/ -*.exe -vendor/ -.env -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/README.md", - Content: `# {{.ProjectName}} - -A REST API service. - -## Running - -` + "```bash" + ` -go run ./cmd/server/main.go -` + "```" + ` - -## Endpoints - -- GET /health -- GET /api/v1/items -- POST /api/v1/items -`, - Mode: 0o644, - }, - }, - PostCreate: []string{"cd {{.ProjectName}} && go mod tidy"}, - } - - s.Templates["go-lib"] = &Template{ - Name: "go-lib", - Description: "Go library package", - Language: "go", - Framework: "stdlib", - Variables: []TemplateVariable{ - {Name: "ProjectName", Description: "Name of the library", Required: true, Type: "string"}, - {Name: "Module", Description: "Go module path", Required: true, Type: "string"}, - {Name: "PackageName", Description: "Go package name", Required: true, Type: "string"}, - {Name: "WithCI", Description: "Include GitHub Actions CI", Default: "true", Type: "bool"}, - }, - Files: []TemplateFile{ - { - Path: "{{.ProjectName}}/{{.PackageName}}.go", - Content: `// Package {{.PackageName}} provides ... -package {{.PackageName}} - -// Version is the library version. -const Version = "0.1.0" -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/{{.PackageName}}_test.go", - Content: `package {{.PackageName}} - -import "testing" - -func TestVersion(t *testing.T) { - if Version == "" { - t.Error("Version should not be empty") - } -} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/example_test.go", - Content: `package {{.PackageName}}_test - -import ( - "fmt" - - "{{.Module}}" -) - -func Example() { - fmt.Println({{.PackageName}}.Version) - // Output: 0.1.0 -} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/go.mod", - Content: `module {{.Module}} - -go 1.21 -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/.github/workflows/ci.yml", - Content: `name: CI -on: [push, pull_request] -jobs: - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - run: go test ./... - - run: go vet ./... -`, - Mode: 0o644, - Condition: "{{.WithCI}}", - }, - { - Path: "{{.ProjectName}}/README.md", - Content: `# {{.ProjectName}} - -` + "```go" + ` -import "{{.Module}}" -` + "```" + ` - -## Installation - -` + "```bash" + ` -go get {{.Module}} -` + "```" + ` -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/.gitignore", - Content: `vendor/ -*.test -`, - Mode: 0o644, - }, - }, - PostCreate: []string{"cd {{.ProjectName}} && go mod tidy"}, - } - - s.Templates["ts-api"] = &Template{ - Name: "ts-api", - Description: "TypeScript API with Express", - Language: "typescript", - Framework: "express", - Variables: []TemplateVariable{ - {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, - {Name: "Port", Description: "Server port", Default: "3000", Type: "string"}, - {Name: "WithDocker", Description: "Include Dockerfile", Default: "true", Type: "bool"}, - }, - Files: []TemplateFile{ - { - Path: "{{.ProjectName}}/src/index.ts", - Content: `import express from 'express'; - -const app = express(); -const port = process.env.PORT || {{.Port}}; - -app.use(express.json()); - -app.get('/health', (req, res) => { - res.json({ status: 'ok' }); -}); - -app.get('/api/v1/items', (req, res) => { - res.json([]); -}); - -app.listen(port, () => { - console.log(` + "`Server running on port ${port}`" + `); -}); - -export default app; -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/tsconfig.json", - Content: `{ - "compilerOptions": { - "target": "ES2020", - "module": "commonjs", - "lib": ["ES2020"], - "outDir": "./dist", - "rootDir": "./src", - "strict": true, - "esModuleInterop": true, - "skipLibCheck": true, - "forceConsistentCasingInFileNames": true, - "resolveJsonModule": true, - "declaration": true - }, - "include": ["src/**/*"], - "exclude": ["node_modules", "dist"] -} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/package.json", - Content: `{ - "name": "{{.ProjectName}}", - "version": "0.1.0", - "description": "{{.ProjectName}} API", - "main": "dist/index.js", - "scripts": { - "build": "tsc", - "start": "node dist/index.js", - "dev": "ts-node src/index.ts", - "test": "jest" - }, - "dependencies": { - "express": "^4.18.0" - }, - "devDependencies": { - "@types/express": "^4.17.0", - "@types/node": "^20.0.0", - "typescript": "^5.0.0", - "ts-node": "^10.9.0" - } -} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/Dockerfile", - Content: `FROM node:20-alpine AS builder -WORKDIR /app -COPY package*.json ./ -RUN npm ci -COPY . . -RUN npm run build - -FROM node:20-alpine -WORKDIR /app -COPY --from=builder /app/dist ./dist -COPY --from=builder /app/package*.json ./ -RUN npm ci --production -EXPOSE {{.Port}} -CMD ["node", "dist/index.js"] -`, - Mode: 0o644, - Condition: "{{.WithDocker}}", - }, - { - Path: "{{.ProjectName}}/.gitignore", - Content: `node_modules/ -dist/ -.env -*.js.map -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/README.md", - Content: `# {{.ProjectName}} - -TypeScript API service. - -## Development - -` + "```bash" + ` -npm install -npm run dev -` + "```" + ` - -## Build - -` + "```bash" + ` -npm run build -npm start -` + "```" + ` -`, - Mode: 0o644, - }, - }, - PostCreate: []string{"cd {{.ProjectName}} && npm install"}, - } - - s.Templates["python-api"] = &Template{ - Name: "python-api", - Description: "Python API with FastAPI", - Language: "python", - Framework: "fastapi", - Variables: []TemplateVariable{ - {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, - {Name: "Port", Description: "Server port", Default: "8000", Type: "string"}, - {Name: "WithDocker", Description: "Include Dockerfile", Default: "true", Type: "bool"}, - }, - Files: []TemplateFile{ - { - Path: "{{.ProjectName}}/app/main.py", - Content: `from fastapi import FastAPI - -app = FastAPI(title="{{.ProjectName}}") - - -@app.get("/health") -def health(): - return {"status": "ok"} - - -@app.get("/api/v1/items") -def list_items(): - return [] - - -@app.post("/api/v1/items", status_code=201) -def create_item(item: dict): - return {"status": "created", "item": item} -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/app/__init__.py", - Content: `"""{{.ProjectName}} application.""" -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/requirements.txt", - Content: `fastapi>=0.100.0 -uvicorn[standard]>=0.23.0 -pydantic>=2.0.0 -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/Dockerfile", - Content: `FROM python:3.11-slim -WORKDIR /app -COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt -COPY . . -EXPOSE {{.Port}} -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "{{.Port}}"] -`, - Mode: 0o644, - Condition: "{{.WithDocker}}", - }, - { - Path: "{{.ProjectName}}/.gitignore", - Content: `__pycache__/ -*.py[cod] -*$py.class -.env -venv/ -.venv/ -dist/ -*.egg-info/ -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/README.md", - Content: `# {{.ProjectName}} - -Python FastAPI service. - -## Setup - -` + "```bash" + ` -python -m venv venv -source venv/bin/activate -pip install -r requirements.txt -` + "```" + ` - -## Run - -` + "```bash" + ` -uvicorn app.main:app --reload --port {{.Port}} -` + "```" + ` -`, - Mode: 0o644, - }, - }, - PostCreate: []string{"cd {{.ProjectName}} && python -m venv venv"}, - } - - s.Templates["python-cli"] = &Template{ - Name: "python-cli", - Description: "Python CLI with Click", - Language: "python", - Framework: "click", - Variables: []TemplateVariable{ - {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, - {Name: "Author", Description: "Author name", Default: "Developer", Type: "string"}, - {Name: "WithTests", Description: "Include test directory", Default: "true", Type: "bool"}, - }, - Files: []TemplateFile{ - { - Path: "{{.ProjectName}}/{{.ProjectName}}/cli.py", - Content: `"""CLI entry point for {{.ProjectName}}.""" -import click - - -@click.group() -@click.version_option(version="0.1.0") -def main(): - """{{.ProjectName}} - A command line tool.""" - pass - - -@main.command() -@click.argument("name", default="World") -def hello(name): - """Say hello.""" - click.echo(f"Hello, {name}!") - - -if __name__ == "__main__": - main() -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/{{.ProjectName}}/__init__.py", - Content: `"""{{.ProjectName}} package.""" -__version__ = "0.1.0" -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/setup.py", - Content: `from setuptools import setup, find_packages - -setup( - name="{{.ProjectName}}", - version="0.1.0", - author="{{.Author}}", - packages=find_packages(), - install_requires=[ - "click>=8.0.0", - ], - entry_points={ - "console_scripts": [ - "{{.ProjectName}}={{.ProjectName}}.cli:main", - ], - }, - python_requires=">=3.8", -) -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/tests/__init__.py", - Content: ``, - Mode: 0o644, - Condition: "{{.WithTests}}", - }, - { - Path: "{{.ProjectName}}/tests/test_cli.py", - Content: `"""Tests for CLI.""" -from click.testing import CliRunner -from {{.ProjectName}}.cli import main - - -def test_hello(): - runner = CliRunner() - result = runner.invoke(main, ["hello"]) - assert result.exit_code == 0 - assert "Hello, World!" in result.output - - -def test_hello_name(): - runner = CliRunner() - result = runner.invoke(main, ["hello", "Test"]) - assert result.exit_code == 0 - assert "Hello, Test!" in result.output -`, - Mode: 0o644, - Condition: "{{.WithTests}}", - }, - { - Path: "{{.ProjectName}}/.gitignore", - Content: `__pycache__/ -*.py[cod] -*$py.class -.env -venv/ -.venv/ -dist/ -*.egg-info/ -build/ -`, - Mode: 0o644, - }, - { - Path: "{{.ProjectName}}/README.md", - Content: `# {{.ProjectName}} - -A command line tool. - -## Installation - -` + "```bash" + ` -pip install -e . -` + "```" + ` - -## Usage - -` + "```bash" + ` -{{.ProjectName}} hello -{{.ProjectName}} hello YourName -` + "```" + ` - -## Author - -{{.Author}} -`, - Mode: 0o644, - }, - }, - PostCreate: []string{"cd {{.ProjectName}} && pip install -e ."}, - } -} - // Generate creates a project from a template. func (s *Scaffolder) Generate(templateName string, vars map[string]string, outputDir string) error { s.mu.RLock() diff --git a/internal/engine/scaffold/scaffold_builtins.go b/internal/engine/scaffold/scaffold_builtins.go new file mode 100644 index 00000000..9292320d --- /dev/null +++ b/internal/engine/scaffold/scaffold_builtins.go @@ -0,0 +1,788 @@ +package scaffold + +// This file holds the built-in project templates (data) registered on every new +// Scaffolder. The Scaffolder type and all generation/rendering logic live in +// scaffold.go. + +func (s *Scaffolder) registerBuiltins() { + s.Templates["go-cli"] = &Template{ + Name: "go-cli", + Description: "Go CLI application with Cobra", + Language: "go", + Framework: "cobra", + Variables: []TemplateVariable{ + {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, + {Name: "Module", Description: "Go module path", Required: true, Type: "string"}, + {Name: "Author", Description: "Author name", Default: "Developer", Type: "string"}, + {Name: "License", Description: "License type", Default: "MIT", Type: "choice", Choices: []string{"MIT", "Apache-2.0", "BSD-3-Clause"}}, + }, + Files: []TemplateFile{ + { + Path: "{{.ProjectName}}/cmd/main.go", + Content: `package main + +import ( + "fmt" + "os" + + "{{.Module}}/internal/cmd" +) + +func main() { + if err := cmd.Execute(); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/internal/cmd/root.go", + Content: `package cmd + +import ( + "fmt" + "os" +) + +// Execute runs the root command. +func Execute() error { + if len(os.Args) < 2 { + fmt.Println("{{.ProjectName}} - A CLI application") + fmt.Println("Usage: {{.ProjectName}} ") + return nil + } + switch os.Args[1] { + case "version": + fmt.Println("{{.ProjectName}} v0.1.0") + default: + return fmt.Errorf("unknown command: %s", os.Args[1]) + } + return nil +} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/go.mod", + Content: `module {{.Module}} + +go 1.21 +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/Makefile", + Content: `.PHONY: build test clean + +BINARY={{.ProjectName}} + +build: + go build -o bin/$(BINARY) ./cmd/main.go + +test: + go test ./... + +clean: + rm -rf bin/ + +lint: + golangci-lint run ./... +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/.gitignore", + Content: `bin/ +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +*.out +vendor/ +.idea/ +.vscode/ +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/README.md", + Content: `# {{.ProjectName}} + +{{.ProjectName}} is a CLI application. + +## Installation + +` + "```bash" + ` +go install {{.Module}}/cmd@latest +` + "```" + ` + +## Usage + +` + "```bash" + ` +{{.ProjectName}} version +` + "```" + ` + +## Author + +{{.Author}} + +## License + +{{.License}} +`, + Mode: 0o644, + }, + }, + PostCreate: []string{"cd {{.ProjectName}} && go mod tidy"}, + } + + s.Templates["go-api"] = &Template{ + Name: "go-api", + Description: "Go REST API with net/http", + Language: "go", + Framework: "net/http", + Variables: []TemplateVariable{ + {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, + {Name: "Module", Description: "Go module path", Required: true, Type: "string"}, + {Name: "Port", Description: "Server port", Default: "8080", Type: "string"}, + {Name: "WithDocker", Description: "Include Dockerfile", Default: "true", Type: "bool"}, + }, + Files: []TemplateFile{ + { + Path: "{{.ProjectName}}/cmd/server/main.go", + Content: `package main + +import ( + "fmt" + "log" + "net/http" + + "{{.Module}}/internal/handler" + "{{.Module}}/internal/middleware" +) + +func main() { + mux := http.NewServeMux() + + mux.HandleFunc("GET /health", handler.Health) + mux.HandleFunc("GET /api/v1/items", handler.ListItems) + mux.HandleFunc("POST /api/v1/items", handler.CreateItem) + + wrapped := middleware.Logger(middleware.Recovery(mux)) + + addr := ":{{.Port}}" + fmt.Printf("Server starting on %s\n", addr) + log.Fatal(http.ListenAndServe(addr, wrapped)) +} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/internal/handler/handler.go", + Content: `package handler + +import ( + "encoding/json" + "net/http" +) + +// Health returns service health status. +func Health(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) +} + +// ListItems returns all items. +func ListItems(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]string{}) +} + +// CreateItem creates a new item. +func CreateItem(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "created"}) +} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/internal/middleware/middleware.go", + Content: `package middleware + +import ( + "log" + "net/http" + "time" +) + +// Logger logs incoming requests. +func Logger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + next.ServeHTTP(w, r) + log.Printf("%s %s %v", r.Method, r.URL.Path, time.Since(start)) + }) +} + +// Recovery recovers from panics. +func Recovery(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + log.Printf("panic recovered: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) +} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/go.mod", + Content: `module {{.Module}} + +go 1.21 +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/Dockerfile", + Content: `FROM golang:1.21-alpine AS builder +WORKDIR /app +COPY go.mod go.sum ./ +RUN go mod download +COPY . . +RUN CGO_ENABLED=0 go build -o server ./cmd/server/main.go + +FROM alpine:latest +RUN apk --no-cache add ca-certificates +WORKDIR /root/ +COPY --from=builder /app/server . +EXPOSE {{.Port}} +CMD ["./server"] +`, + Mode: 0o644, + Condition: "{{.WithDocker}}", + }, + { + Path: "{{.ProjectName}}/.gitignore", + Content: `bin/ +*.exe +vendor/ +.env +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/README.md", + Content: `# {{.ProjectName}} + +A REST API service. + +## Running + +` + "```bash" + ` +go run ./cmd/server/main.go +` + "```" + ` + +## Endpoints + +- GET /health +- GET /api/v1/items +- POST /api/v1/items +`, + Mode: 0o644, + }, + }, + PostCreate: []string{"cd {{.ProjectName}} && go mod tidy"}, + } + + s.Templates["go-lib"] = &Template{ + Name: "go-lib", + Description: "Go library package", + Language: "go", + Framework: "stdlib", + Variables: []TemplateVariable{ + {Name: "ProjectName", Description: "Name of the library", Required: true, Type: "string"}, + {Name: "Module", Description: "Go module path", Required: true, Type: "string"}, + {Name: "PackageName", Description: "Go package name", Required: true, Type: "string"}, + {Name: "WithCI", Description: "Include GitHub Actions CI", Default: "true", Type: "bool"}, + }, + Files: []TemplateFile{ + { + Path: "{{.ProjectName}}/{{.PackageName}}.go", + Content: `// Package {{.PackageName}} provides ... +package {{.PackageName}} + +// Version is the library version. +const Version = "0.1.0" +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/{{.PackageName}}_test.go", + Content: `package {{.PackageName}} + +import "testing" + +func TestVersion(t *testing.T) { + if Version == "" { + t.Error("Version should not be empty") + } +} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/example_test.go", + Content: `package {{.PackageName}}_test + +import ( + "fmt" + + "{{.Module}}" +) + +func Example() { + fmt.Println({{.PackageName}}.Version) + // Output: 0.1.0 +} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/go.mod", + Content: `module {{.Module}} + +go 1.21 +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/.github/workflows/ci.yml", + Content: `name: CI +on: [push, pull_request] +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.21' + - run: go test ./... + - run: go vet ./... +`, + Mode: 0o644, + Condition: "{{.WithCI}}", + }, + { + Path: "{{.ProjectName}}/README.md", + Content: `# {{.ProjectName}} + +` + "```go" + ` +import "{{.Module}}" +` + "```" + ` + +## Installation + +` + "```bash" + ` +go get {{.Module}} +` + "```" + ` +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/.gitignore", + Content: `vendor/ +*.test +`, + Mode: 0o644, + }, + }, + PostCreate: []string{"cd {{.ProjectName}} && go mod tidy"}, + } + + s.Templates["ts-api"] = &Template{ + Name: "ts-api", + Description: "TypeScript API with Express", + Language: "typescript", + Framework: "express", + Variables: []TemplateVariable{ + {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, + {Name: "Port", Description: "Server port", Default: "3000", Type: "string"}, + {Name: "WithDocker", Description: "Include Dockerfile", Default: "true", Type: "bool"}, + }, + Files: []TemplateFile{ + { + Path: "{{.ProjectName}}/src/index.ts", + Content: `import express from 'express'; + +const app = express(); +const port = process.env.PORT || {{.Port}}; + +app.use(express.json()); + +app.get('/health', (req, res) => { + res.json({ status: 'ok' }); +}); + +app.get('/api/v1/items', (req, res) => { + res.json([]); +}); + +app.listen(port, () => { + console.log(` + "`Server running on port ${port}`" + `); +}); + +export default app; +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/tsconfig.json", + Content: `{ + "compilerOptions": { + "target": "ES2020", + "module": "commonjs", + "lib": ["ES2020"], + "outDir": "./dist", + "rootDir": "./src", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true, + "declaration": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist"] +} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/package.json", + Content: `{ + "name": "{{.ProjectName}}", + "version": "0.1.0", + "description": "{{.ProjectName}} API", + "main": "dist/index.js", + "scripts": { + "build": "tsc", + "start": "node dist/index.js", + "dev": "ts-node src/index.ts", + "test": "jest" + }, + "dependencies": { + "express": "^4.18.0" + }, + "devDependencies": { + "@types/express": "^4.17.0", + "@types/node": "^20.0.0", + "typescript": "^5.0.0", + "ts-node": "^10.9.0" + } +} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/Dockerfile", + Content: `FROM node:20-alpine AS builder +WORKDIR /app +COPY package*.json ./ +RUN npm ci +COPY . . +RUN npm run build + +FROM node:20-alpine +WORKDIR /app +COPY --from=builder /app/dist ./dist +COPY --from=builder /app/package*.json ./ +RUN npm ci --production +EXPOSE {{.Port}} +CMD ["node", "dist/index.js"] +`, + Mode: 0o644, + Condition: "{{.WithDocker}}", + }, + { + Path: "{{.ProjectName}}/.gitignore", + Content: `node_modules/ +dist/ +.env +*.js.map +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/README.md", + Content: `# {{.ProjectName}} + +TypeScript API service. + +## Development + +` + "```bash" + ` +npm install +npm run dev +` + "```" + ` + +## Build + +` + "```bash" + ` +npm run build +npm start +` + "```" + ` +`, + Mode: 0o644, + }, + }, + PostCreate: []string{"cd {{.ProjectName}} && npm install"}, + } + + s.Templates["python-api"] = &Template{ + Name: "python-api", + Description: "Python API with FastAPI", + Language: "python", + Framework: "fastapi", + Variables: []TemplateVariable{ + {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, + {Name: "Port", Description: "Server port", Default: "8000", Type: "string"}, + {Name: "WithDocker", Description: "Include Dockerfile", Default: "true", Type: "bool"}, + }, + Files: []TemplateFile{ + { + Path: "{{.ProjectName}}/app/main.py", + Content: `from fastapi import FastAPI + +app = FastAPI(title="{{.ProjectName}}") + + +@app.get("/health") +def health(): + return {"status": "ok"} + + +@app.get("/api/v1/items") +def list_items(): + return [] + + +@app.post("/api/v1/items", status_code=201) +def create_item(item: dict): + return {"status": "created", "item": item} +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/app/__init__.py", + Content: `"""{{.ProjectName}} application.""" +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/requirements.txt", + Content: `fastapi>=0.100.0 +uvicorn[standard]>=0.23.0 +pydantic>=2.0.0 +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/Dockerfile", + Content: `FROM python:3.11-slim +WORKDIR /app +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt +COPY . . +EXPOSE {{.Port}} +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "{{.Port}}"] +`, + Mode: 0o644, + Condition: "{{.WithDocker}}", + }, + { + Path: "{{.ProjectName}}/.gitignore", + Content: `__pycache__/ +*.py[cod] +*$py.class +.env +venv/ +.venv/ +dist/ +*.egg-info/ +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/README.md", + Content: `# {{.ProjectName}} + +Python FastAPI service. + +## Setup + +` + "```bash" + ` +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt +` + "```" + ` + +## Run + +` + "```bash" + ` +uvicorn app.main:app --reload --port {{.Port}} +` + "```" + ` +`, + Mode: 0o644, + }, + }, + PostCreate: []string{"cd {{.ProjectName}} && python -m venv venv"}, + } + + s.Templates["python-cli"] = &Template{ + Name: "python-cli", + Description: "Python CLI with Click", + Language: "python", + Framework: "click", + Variables: []TemplateVariable{ + {Name: "ProjectName", Description: "Name of the project", Required: true, Type: "string"}, + {Name: "Author", Description: "Author name", Default: "Developer", Type: "string"}, + {Name: "WithTests", Description: "Include test directory", Default: "true", Type: "bool"}, + }, + Files: []TemplateFile{ + { + Path: "{{.ProjectName}}/{{.ProjectName}}/cli.py", + Content: `"""CLI entry point for {{.ProjectName}}.""" +import click + + +@click.group() +@click.version_option(version="0.1.0") +def main(): + """{{.ProjectName}} - A command line tool.""" + pass + + +@main.command() +@click.argument("name", default="World") +def hello(name): + """Say hello.""" + click.echo(f"Hello, {name}!") + + +if __name__ == "__main__": + main() +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/{{.ProjectName}}/__init__.py", + Content: `"""{{.ProjectName}} package.""" +__version__ = "0.1.0" +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/setup.py", + Content: `from setuptools import setup, find_packages + +setup( + name="{{.ProjectName}}", + version="0.1.0", + author="{{.Author}}", + packages=find_packages(), + install_requires=[ + "click>=8.0.0", + ], + entry_points={ + "console_scripts": [ + "{{.ProjectName}}={{.ProjectName}}.cli:main", + ], + }, + python_requires=">=3.8", +) +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/tests/__init__.py", + Content: ``, + Mode: 0o644, + Condition: "{{.WithTests}}", + }, + { + Path: "{{.ProjectName}}/tests/test_cli.py", + Content: `"""Tests for CLI.""" +from click.testing import CliRunner +from {{.ProjectName}}.cli import main + + +def test_hello(): + runner = CliRunner() + result = runner.invoke(main, ["hello"]) + assert result.exit_code == 0 + assert "Hello, World!" in result.output + + +def test_hello_name(): + runner = CliRunner() + result = runner.invoke(main, ["hello", "Test"]) + assert result.exit_code == 0 + assert "Hello, Test!" in result.output +`, + Mode: 0o644, + Condition: "{{.WithTests}}", + }, + { + Path: "{{.ProjectName}}/.gitignore", + Content: `__pycache__/ +*.py[cod] +*$py.class +.env +venv/ +.venv/ +dist/ +*.egg-info/ +build/ +`, + Mode: 0o644, + }, + { + Path: "{{.ProjectName}}/README.md", + Content: `# {{.ProjectName}} + +A command line tool. + +## Installation + +` + "```bash" + ` +pip install -e . +` + "```" + ` + +## Usage + +` + "```bash" + ` +{{.ProjectName}} hello +{{.ProjectName}} hello YourName +` + "```" + ` + +## Author + +{{.Author}} +`, + Mode: 0o644, + }, + }, + PostCreate: []string{"cd {{.ProjectName}} && pip install -e ."}, + } +} From bf25f0f2a1bc613b4cf4bc612945043385b31f44 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:09:45 +0530 Subject: [PATCH 08/20] refactor(tool): move built-in codegen templates to codegen_builtins.go --- internal/tool/codegen.go | 851 ----------------------------- internal/tool/codegen_builtins.go | 856 ++++++++++++++++++++++++++++++ 2 files changed, 856 insertions(+), 851 deletions(-) create mode 100644 internal/tool/codegen_builtins.go diff --git a/internal/tool/codegen.go b/internal/tool/codegen.go index 06385cbb..8033821b 100644 --- a/internal/tool/codegen.go +++ b/internal/tool/codegen.go @@ -48,857 +48,6 @@ func NewCodeGenerator() *CodeGenerator { return cg } -func (cg *CodeGenerator) registerBuiltins() { - // Go templates - cg.Templates["go-handler"] = &CodeTemplate{ - Name: "go-handler", - Description: "HTTP handler function with request parsing, validation, and response", - Language: "go", - Template: `package {{.Package}} - -import ( - "encoding/json" - "net/http" -) - -// {{.Name}}Request represents the request body for {{.Name}}. -type {{.Name}}Request struct { - // TODO: define request fields -} - -// {{.Name}}Response represents the response body for {{.Name}}. -type {{.Name}}Response struct { - // TODO: define response fields -} - -// {{.Name}}Handler handles {{.Method}} requests for {{.Path}}. -func {{.Name}}Handler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.Method{{.Method}} { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - - var req {{.Name}}Request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid request body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // TODO: implement handler logic - - resp := {{.Name}}Response{} - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(resp) -} -`, - Variables: []TemplateVar{ - {Name: "Package", Description: "Go package name", Required: true, Default: "handlers"}, - {Name: "Name", Description: "Handler name (PascalCase)", Required: true, Default: ""}, - {Name: "Method", Description: "HTTP method (Get, Post, Put, Delete)", Required: false, Default: "Post"}, - {Name: "Path", Description: "URL path for the endpoint", Required: false, Default: "/api/resource"}, - }, - Output: "{{.Name | lower}}_handler.go", - } - - cg.Templates["go-middleware"] = &CodeTemplate{ - Name: "go-middleware", - Description: "HTTP middleware with next handler chaining", - Language: "go", - Template: `package {{.Package}} - -import ( - "log" - "net/http" - "time" -) - -// {{.Name}} is a middleware that {{.Description}}. -func {{.Name}}(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - - // Pre-processing - log.Printf("[{{.Name}}] %s %s started", r.Method, r.URL.Path) - - // TODO: implement middleware logic - - // Call next handler - next.ServeHTTP(w, r) - - // Post-processing - duration := time.Since(start) - log.Printf("[{{.Name}}] %s %s completed in %v", r.Method, r.URL.Path, duration) - }) -} -`, - Variables: []TemplateVar{ - {Name: "Package", Description: "Go package name", Required: true, Default: "middleware"}, - {Name: "Name", Description: "Middleware function name", Required: true, Default: ""}, - {Name: "Description", Description: "What the middleware does", Required: false, Default: "processes requests"}, - }, - Output: "{{.Name | lower}}.go", - } - - cg.Templates["go-crud"] = &CodeTemplate{ - Name: "go-crud", - Description: "Full CRUD functions for a resource (Create, Get, List, Update, Delete)", - Language: "go", - Template: `package {{.Package}} - -import ( - "encoding/json" - "fmt" - "net/http" - "sync" -) - -// {{.Resource}} represents the {{.Resource}} entity. -type {{.Resource}} struct { - ID string ` + "`json:\"id\"`" + ` - Name string ` + "`json:\"name\"`" + ` - // TODO: add fields -} - -// {{.Resource}}Store manages {{.Resource}} persistence. -type {{.Resource}}Store struct { - mu sync.RWMutex - items map[string]*{{.Resource}} -} - -// New{{.Resource}}Store creates a new store. -func New{{.Resource}}Store() *{{.Resource}}Store { - return &{{.Resource}}Store{items: make(map[string]*{{.Resource}})} -} - -// Create{{.Resource}} adds a new {{.Resource}}. -func (s *{{.Resource}}Store) Create{{.Resource}}(item *{{.Resource}}) error { - s.mu.Lock() - defer s.mu.Unlock() - if _, exists := s.items[item.ID]; exists { - return fmt.Errorf("{{.Resource}} with ID %s already exists", item.ID) - } - s.items[item.ID] = item - return nil -} - -// Get{{.Resource}} retrieves a {{.Resource}} by ID. -func (s *{{.Resource}}Store) Get{{.Resource}}(id string) (*{{.Resource}}, error) { - s.mu.RLock() - defer s.mu.RUnlock() - item, ok := s.items[id] - if !ok { - return nil, fmt.Errorf("{{.Resource}} with ID %s not found", id) - } - return item, nil -} - -// List{{.Resource}}s returns all {{.Resource}} items. -func (s *{{.Resource}}Store) List{{.Resource}}s() []*{{.Resource}} { - s.mu.RLock() - defer s.mu.RUnlock() - result := make([]*{{.Resource}}, 0, len(s.items)) - for _, item := range s.items { - result = append(result, item) - } - return result -} - -// Update{{.Resource}} updates an existing {{.Resource}}. -func (s *{{.Resource}}Store) Update{{.Resource}}(item *{{.Resource}}) error { - s.mu.Lock() - defer s.mu.Unlock() - if _, exists := s.items[item.ID]; !exists { - return fmt.Errorf("{{.Resource}} with ID %s not found", item.ID) - } - s.items[item.ID] = item - return nil -} - -// Delete{{.Resource}} removes a {{.Resource}} by ID. -func (s *{{.Resource}}Store) Delete{{.Resource}}(id string) error { - s.mu.Lock() - defer s.mu.Unlock() - if _, exists := s.items[id]; !exists { - return fmt.Errorf("{{.Resource}} with ID %s not found", id) - } - delete(s.items, id) - return nil -} - -// Handle{{.Resource}}s returns an HTTP handler for {{.Resource}} CRUD operations. -func (s *{{.Resource}}Store) Handle{{.Resource}}s(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - switch r.Method { - case http.MethodGet: - items := s.List{{.Resource}}s() - _ = json.NewEncoder(w).Encode(items) - case http.MethodPost: - var item {{.Resource}} - if err := json.NewDecoder(r.Body).Decode(&item); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if err := s.Create{{.Resource}}(&item); err != nil { - http.Error(w, err.Error(), http.StatusConflict) - return - } - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(item) - case http.MethodPut: - var item {{.Resource}} - if err := json.NewDecoder(r.Body).Decode(&item); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if err := s.Update{{.Resource}}(&item); err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - _ = json.NewEncoder(w).Encode(item) - case http.MethodDelete: - id := r.URL.Query().Get("id") - if err := s.Delete{{.Resource}}(id); err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - w.WriteHeader(http.StatusNoContent) - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} -`, - Variables: []TemplateVar{ - {Name: "Package", Description: "Go package name", Required: true, Default: "models"}, - {Name: "Resource", Description: "Resource name (PascalCase)", Required: true, Default: ""}, - }, - Output: "{{.Resource | lower}}_crud.go", - } - - cg.Templates["go-test-table"] = &CodeTemplate{ - Name: "go-test-table", - Description: "Table-driven test with subtests", - Language: "go", - Template: `package {{.Package}} - -import ( - "testing" -) - -func Test{{.Function}}(t *testing.T) { - tests := []struct { - name string - input string - want string - wantErr bool - }{ - { - name: "valid input", - input: "hello", - want: "expected", - wantErr: false, - }, - { - name: "empty input", - input: "", - want: "", - wantErr: true, - }, - // TODO: add more test cases - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := {{.Function}}(tt.input) - if (err != nil) != tt.wantErr { - t.Errorf("{{.Function}}() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("{{.Function}}() = %v, want %v", got, tt.want) - } - }) - } -} -`, - Variables: []TemplateVar{ - {Name: "Package", Description: "Go package name", Required: true, Default: ""}, - {Name: "Function", Description: "Function name to test", Required: true, Default: ""}, - }, - Output: "{{.Function | lower}}_test.go", - } - - cg.Templates["go-interface"] = &CodeTemplate{ - Name: "go-interface", - Description: "Interface definition with mock implementation", - Language: "go", - Template: `package {{.Package}} - -import ( - "context" - "sync" -) - -// {{.Name}} defines the interface for {{.Description}}. -type {{.Name}} interface { - Get(ctx context.Context, id string) (interface{}, error) - List(ctx context.Context) ([]interface{}, error) - Create(ctx context.Context, item interface{}) error - Update(ctx context.Context, id string, item interface{}) error - Delete(ctx context.Context, id string) error -} - -// Mock{{.Name}} is a test double for {{.Name}}. -type Mock{{.Name}} struct { - mu sync.Mutex - GetFunc func(ctx context.Context, id string) (interface{}, error) - ListFunc func(ctx context.Context) ([]interface{}, error) - CreateFunc func(ctx context.Context, item interface{}) error - UpdateFunc func(ctx context.Context, id string, item interface{}) error - DeleteFunc func(ctx context.Context, id string) error - Calls []string -} - -func (m *Mock{{.Name}}) Get(ctx context.Context, id string) (interface{}, error) { - m.mu.Lock() - m.Calls = append(m.Calls, "Get") - m.mu.Unlock() - if m.GetFunc != nil { - return m.GetFunc(ctx, id) - } - return nil, nil -} - -func (m *Mock{{.Name}}) List(ctx context.Context) ([]interface{}, error) { - m.mu.Lock() - m.Calls = append(m.Calls, "List") - m.mu.Unlock() - if m.ListFunc != nil { - return m.ListFunc(ctx) - } - return nil, nil -} - -func (m *Mock{{.Name}}) Create(ctx context.Context, item interface{}) error { - m.mu.Lock() - m.Calls = append(m.Calls, "Create") - m.mu.Unlock() - if m.CreateFunc != nil { - return m.CreateFunc(ctx, item) - } - return nil -} - -func (m *Mock{{.Name}}) Update(ctx context.Context, id string, item interface{}) error { - m.mu.Lock() - m.Calls = append(m.Calls, "Update") - m.mu.Unlock() - if m.UpdateFunc != nil { - return m.UpdateFunc(ctx, id, item) - } - return nil -} - -func (m *Mock{{.Name}}) Delete(ctx context.Context, id string) error { - m.mu.Lock() - m.Calls = append(m.Calls, "Delete") - m.mu.Unlock() - if m.DeleteFunc != nil { - return m.DeleteFunc(ctx, id) - } - return nil -} -`, - Variables: []TemplateVar{ - {Name: "Package", Description: "Go package name", Required: true, Default: ""}, - {Name: "Name", Description: "Interface name (PascalCase)", Required: true, Default: ""}, - {Name: "Description", Description: "What the interface represents", Required: false, Default: "a service"}, - }, - Output: "{{.Name | lower}}.go", - } - - cg.Templates["go-errors"] = &CodeTemplate{ - Name: "go-errors", - Description: "Custom error type with constructors", - Language: "go", - Template: `package {{.Package}} - -import "fmt" - -// {{.Name}}Error represents an error in the {{.Domain}} domain. -type {{.Name}}Error struct { - Code string - Message string - Err error -} - -func (e *{{.Name}}Error) Error() string { - if e.Err != nil { - return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Err) - } - return fmt.Sprintf("[%s] %s", e.Code, e.Message) -} - -func (e *{{.Name}}Error) Unwrap() error { - return e.Err -} - -// New{{.Name}}Error creates a new {{.Name}}Error. -func New{{.Name}}Error(code, message string) *{{.Name}}Error { - return &{{.Name}}Error{Code: code, Message: message} -} - -// Wrap{{.Name}}Error wraps an existing error with a {{.Name}}Error. -func Wrap{{.Name}}Error(code, message string, err error) *{{.Name}}Error { - return &{{.Name}}Error{Code: code, Message: message, Err: err} -} - -// ErrNotFound creates a not-found error. -func ErrNotFound(resource, id string) *{{.Name}}Error { - return New{{.Name}}Error("NOT_FOUND", fmt.Sprintf("%s with ID %s not found", resource, id)) -} - -// ErrValidation creates a validation error. -func ErrValidation(field, reason string) *{{.Name}}Error { - return New{{.Name}}Error("VALIDATION", fmt.Sprintf("field %s: %s", field, reason)) -} - -// ErrInternal creates an internal error wrapping the cause. -func ErrInternal(message string, err error) *{{.Name}}Error { - return Wrap{{.Name}}Error("INTERNAL", message, err) -} - -// Is{{.Name}}Error checks if an error is a {{.Name}}Error with the given code. -func Is{{.Name}}Error(err error, code string) bool { - if e, ok := err.(*{{.Name}}Error); ok { - return e.Code == code - } - return false -} -`, - Variables: []TemplateVar{ - {Name: "Package", Description: "Go package name", Required: true, Default: ""}, - {Name: "Name", Description: "Error type prefix (PascalCase)", Required: true, Default: ""}, - {Name: "Domain", Description: "Domain this error belongs to", Required: false, Default: "application"}, - }, - Output: "{{.Name | lower}}_errors.go", - } - - cg.Templates["go-config"] = &CodeTemplate{ - Name: "go-config", - Description: "Config struct with environment variable loading and validation", - Language: "go", - Template: `package {{.Package}} - -import ( - "fmt" - "os" - "strconv" - "strings" -) - -// {{.Name}}Config holds configuration for {{.Description}}. -type {{.Name}}Config struct { - Host string - Port int - Debug bool - LogLevel string - // TODO: add more config fields -} - -// Default{{.Name}}Config returns the default configuration. -func Default{{.Name}}Config() *{{.Name}}Config { - return &{{.Name}}Config{ - Host: "localhost", - Port: 8080, - Debug: false, - LogLevel: "info", - } -} - -// Load{{.Name}}Config loads configuration from environment variables. -// Variables are prefixed with {{.Prefix}}_. -func Load{{.Name}}Config() (*{{.Name}}Config, error) { - cfg := Default{{.Name}}Config() - - if v := os.Getenv("{{.Prefix}}_HOST"); v != "" { - cfg.Host = v - } - if v := os.Getenv("{{.Prefix}}_PORT"); v != "" { - port, err := strconv.Atoi(v) - if err != nil { - return nil, fmt.Errorf("invalid {{.Prefix}}_PORT: %w", err) - } - cfg.Port = port - } - if v := os.Getenv("{{.Prefix}}_DEBUG"); v != "" { - cfg.Debug = strings.ToLower(v) == "true" || v == "1" - } - if v := os.Getenv("{{.Prefix}}_LOG_LEVEL"); v != "" { - cfg.LogLevel = v - } - - if err := cfg.Validate(); err != nil { - return nil, err - } - return cfg, nil -} - -// Validate checks that the configuration is valid. -func (c *{{.Name}}Config) Validate() error { - if c.Host == "" { - return fmt.Errorf("host must not be empty") - } - if c.Port < 1 || c.Port > 65535 { - return fmt.Errorf("port must be between 1 and 65535, got %d", c.Port) - } - validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true} - if !validLevels[strings.ToLower(c.LogLevel)] { - return fmt.Errorf("invalid log level: %s (must be debug, info, warn, or error)", c.LogLevel) - } - return nil -} - -// Address returns the host:port address string. -func (c *{{.Name}}Config) Address() string { - return fmt.Sprintf("%s:%d", c.Host, c.Port) -} -`, - Variables: []TemplateVar{ - {Name: "Package", Description: "Go package name", Required: true, Default: "config"}, - {Name: "Name", Description: "Config name (PascalCase)", Required: true, Default: ""}, - {Name: "Prefix", Description: "Environment variable prefix (UPPER_CASE)", Required: true, Default: "APP"}, - {Name: "Description", Description: "What this config is for", Required: false, Default: "the application"}, - }, - Output: "{{.Name | lower}}_config.go", - } - - // Python templates - cg.Templates["py-fastapi-endpoint"] = &CodeTemplate{ - Name: "py-fastapi-endpoint", - Description: "FastAPI route with Pydantic model", - Language: "python", - Template: `from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field -from typing import Optional - -router = APIRouter(prefix="/{{.Prefix}}", tags=["{{.Tag}}"]) - - -class {{.Name}}Request(BaseModel): - """Request model for {{.Name}}.""" - name: str = Field(..., description="Name field") - # TODO: add request fields - - -class {{.Name}}Response(BaseModel): - """Response model for {{.Name}}.""" - id: str - name: str - # TODO: add response fields - - -@router.post("/", response_model={{.Name}}Response, status_code=201) -async def create_{{.NameLower}}(request: {{.Name}}Request) -> {{.Name}}Response: - """Create a new {{.Name}}.""" - # TODO: implement creation logic - return {{.Name}}Response(id="generated-id", name=request.name) - - -@router.get("/{item_id}", response_model={{.Name}}Response) -async def get_{{.NameLower}}(item_id: str) -> {{.Name}}Response: - """Get a {{.Name}} by ID.""" - # TODO: implement retrieval logic - raise HTTPException(status_code=404, detail="{{.Name}} not found") - - -@router.put("/{item_id}", response_model={{.Name}}Response) -async def update_{{.NameLower}}(item_id: str, request: {{.Name}}Request) -> {{.Name}}Response: - """Update a {{.Name}}.""" - # TODO: implement update logic - return {{.Name}}Response(id=item_id, name=request.name) - - -@router.delete("/{item_id}", status_code=204) -async def delete_{{.NameLower}}(item_id: str) -> None: - """Delete a {{.Name}}.""" - # TODO: implement deletion logic - pass -`, - Variables: []TemplateVar{ - {Name: "Name", Description: "Resource name (PascalCase)", Required: true, Default: ""}, - {Name: "NameLower", Description: "Resource name (lowercase)", Required: true, Default: ""}, - {Name: "Prefix", Description: "URL prefix", Required: false, Default: "api"}, - {Name: "Tag", Description: "OpenAPI tag", Required: false, Default: "default"}, - }, - Output: "{{.NameLower}}_router.py", - } - - cg.Templates["py-test-class"] = &CodeTemplate{ - Name: "py-test-class", - Description: "Pytest test class with setup/teardown", - Language: "python", - Template: `import pytest - - -class Test{{.Name}}: - """Tests for {{.Name}}.""" - - def setup_method(self): - """Set up test fixtures.""" - # TODO: initialize test fixtures - self.subject = None - - def teardown_method(self): - """Clean up after tests.""" - # TODO: clean up resources - pass - - def test_{{.MethodUnderTest}}_with_valid_input(self): - """Test {{.MethodUnderTest}} with valid input.""" - # Arrange - expected = None # TODO: set expected value - - # Act - result = self.subject.{{.MethodUnderTest}}() - - # Assert - assert result == expected - - def test_{{.MethodUnderTest}}_with_invalid_input(self): - """Test {{.MethodUnderTest}} raises on invalid input.""" - with pytest.raises(ValueError): - self.subject.{{.MethodUnderTest}}(None) - - def test_{{.MethodUnderTest}}_edge_case(self): - """Test {{.MethodUnderTest}} handles edge cases.""" - # TODO: implement edge case test - pass -`, - Variables: []TemplateVar{ - {Name: "Name", Description: "Class under test (PascalCase)", Required: true, Default: ""}, - {Name: "MethodUnderTest", Description: "Primary method to test", Required: true, Default: "execute"}, - }, - Output: "test_{{.Name | lower}}.py", - } - - cg.Templates["py-dataclass"] = &CodeTemplate{ - Name: "py-dataclass", - Description: "Dataclass with validation", - Language: "python", - Template: `from dataclasses import dataclass, field -from typing import Optional, List - - -@dataclass -class {{.Name}}: - """{{.Description}}""" - - name: str - value: int = 0 - tags: List[str] = field(default_factory=list) - metadata: Optional[str] = None - - def __post_init__(self): - """Validate fields after initialization.""" - if not self.name: - raise ValueError("name must not be empty") - if self.value < 0: - raise ValueError("value must be non-negative") - # TODO: add more validation - - def to_dict(self) -> dict: - """Convert to dictionary.""" - return { - "name": self.name, - "value": self.value, - "tags": list(self.tags), - "metadata": self.metadata, - } - - @classmethod - def from_dict(cls, data: dict) -> "{{.Name}}": - """Create instance from dictionary.""" - return cls( - name=data["name"], - value=data.get("value", 0), - tags=data.get("tags", []), - metadata=data.get("metadata"), - ) -`, - Variables: []TemplateVar{ - {Name: "Name", Description: "Class name (PascalCase)", Required: true, Default: ""}, - {Name: "Description", Description: "Class description", Required: false, Default: "A data model"}, - }, - Output: "{{.Name | lower}}.py", - } - - // TypeScript templates - cg.Templates["ts-react-component"] = &CodeTemplate{ - Name: "ts-react-component", - Description: "Functional React component with props interface", - Language: "typescript", - Template: `import React from 'react'; - -interface {{.Name}}Props { - title: string; - className?: string; - children?: React.ReactNode; - onClick?: () => void; -} - -/** - * {{.Description}} - */ -export const {{.Name}}: React.FC<{{.Name}}Props> = ({ - title, - className = '', - children, - onClick, -}) => { - return ( -
-

{title}

- {children} -
- ); -}; - -export default {{.Name}}; -`, - Variables: []TemplateVar{ - {Name: "Name", Description: "Component name (PascalCase)", Required: true, Default: ""}, - {Name: "Description", Description: "Component description", Required: false, Default: "A React component"}, - }, - Output: "{{.Name}}.tsx", - } - - cg.Templates["ts-express-router"] = &CodeTemplate{ - Name: "ts-express-router", - Description: "Express router with middleware", - Language: "typescript", - Template: `import { Router, Request, Response, NextFunction } from 'express'; - -const router = Router(); - -// Middleware for this router -function validate{{.Name}}(req: Request, res: Response, next: NextFunction): void { - // TODO: implement validation - next(); -} - -// GET /{{.Path}} -router.get('/', async (req: Request, res: Response) => { - try { - // TODO: implement list - res.json({ items: [] }); - } catch (error) { - res.status(500).json({ error: 'Internal server error' }); - } -}); - -// GET /{{.Path}}/:id -router.get('/:id', async (req: Request, res: Response) => { - try { - const { id } = req.params; - // TODO: implement get by id - res.json({ id }); - } catch (error) { - res.status(500).json({ error: 'Internal server error' }); - } -}); - -// POST /{{.Path}} -router.post('/', validate{{.Name}}, async (req: Request, res: Response) => { - try { - // TODO: implement create - res.status(201).json({ id: 'new-id', ...req.body }); - } catch (error) { - res.status(500).json({ error: 'Internal server error' }); - } -}); - -// PUT /{{.Path}}/:id -router.put('/:id', validate{{.Name}}, async (req: Request, res: Response) => { - try { - const { id } = req.params; - // TODO: implement update - res.json({ id, ...req.body }); - } catch (error) { - res.status(500).json({ error: 'Internal server error' }); - } -}); - -// DELETE /{{.Path}}/:id -router.delete('/:id', async (req: Request, res: Response) => { - try { - const { id } = req.params; - // TODO: implement delete - res.status(204).send(); - } catch (error) { - res.status(500).json({ error: 'Internal server error' }); - } -}); - -export default router; -`, - Variables: []TemplateVar{ - {Name: "Name", Description: "Resource name (PascalCase)", Required: true, Default: ""}, - {Name: "Path", Description: "Route path", Required: false, Default: "resources"}, - }, - Output: "{{.Name | lower}}.router.ts", - } - - cg.Templates["ts-test-describe"] = &CodeTemplate{ - Name: "ts-test-describe", - Description: "Jest/Vitest describe block with test cases", - Language: "typescript", - Template: `import { describe, it, expect, beforeEach, afterEach } from 'vitest'; - -describe('{{.Name}}', () => { - let subject: any; - - beforeEach(() => { - // TODO: set up test fixtures - subject = null; - }); - - afterEach(() => { - // TODO: clean up - }); - - describe('{{.Method}}', () => { - it('should handle valid input', () => { - // Arrange - const input = {}; - - // Act - const result = subject.{{.Method}}(input); - - // Assert - expect(result).toBeDefined(); - }); - - it('should throw on invalid input', () => { - expect(() => subject.{{.Method}}(null)).toThrow(); - }); - - it('should handle edge cases', () => { - // TODO: implement edge case test - expect(true).toBe(true); - }); - }); -}); -`, - Variables: []TemplateVar{ - {Name: "Name", Description: "Module/class under test", Required: true, Default: ""}, - {Name: "Method", Description: "Method being tested", Required: true, Default: "execute"}, - }, - Output: "{{.Name | lower}}.test.ts", - } -} - // Generate renders a template with the given variables. func (cg *CodeGenerator) Generate(templateName string, vars map[string]string) (string, error) { cg.mu.RLock() diff --git a/internal/tool/codegen_builtins.go b/internal/tool/codegen_builtins.go new file mode 100644 index 00000000..4cc3a83d --- /dev/null +++ b/internal/tool/codegen_builtins.go @@ -0,0 +1,856 @@ +package tool + +// This file holds the built-in code-generation templates (data) registered on +// every new CodeGenerator. The CodeGenerator/CodeGenTool types and all +// rendering, listing, and suggestion logic live in codegen.go. + +func (cg *CodeGenerator) registerBuiltins() { + // Go templates + cg.Templates["go-handler"] = &CodeTemplate{ + Name: "go-handler", + Description: "HTTP handler function with request parsing, validation, and response", + Language: "go", + Template: `package {{.Package}} + +import ( + "encoding/json" + "net/http" +) + +// {{.Name}}Request represents the request body for {{.Name}}. +type {{.Name}}Request struct { + // TODO: define request fields +} + +// {{.Name}}Response represents the response body for {{.Name}}. +type {{.Name}}Response struct { + // TODO: define response fields +} + +// {{.Name}}Handler handles {{.Method}} requests for {{.Path}}. +func {{.Name}}Handler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.Method{{.Method}} { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req {{.Name}}Request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // TODO: implement handler logic + + resp := {{.Name}}Response{} + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) +} +`, + Variables: []TemplateVar{ + {Name: "Package", Description: "Go package name", Required: true, Default: "handlers"}, + {Name: "Name", Description: "Handler name (PascalCase)", Required: true, Default: ""}, + {Name: "Method", Description: "HTTP method (Get, Post, Put, Delete)", Required: false, Default: "Post"}, + {Name: "Path", Description: "URL path for the endpoint", Required: false, Default: "/api/resource"}, + }, + Output: "{{.Name | lower}}_handler.go", + } + + cg.Templates["go-middleware"] = &CodeTemplate{ + Name: "go-middleware", + Description: "HTTP middleware with next handler chaining", + Language: "go", + Template: `package {{.Package}} + +import ( + "log" + "net/http" + "time" +) + +// {{.Name}} is a middleware that {{.Description}}. +func {{.Name}}(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Pre-processing + log.Printf("[{{.Name}}] %s %s started", r.Method, r.URL.Path) + + // TODO: implement middleware logic + + // Call next handler + next.ServeHTTP(w, r) + + // Post-processing + duration := time.Since(start) + log.Printf("[{{.Name}}] %s %s completed in %v", r.Method, r.URL.Path, duration) + }) +} +`, + Variables: []TemplateVar{ + {Name: "Package", Description: "Go package name", Required: true, Default: "middleware"}, + {Name: "Name", Description: "Middleware function name", Required: true, Default: ""}, + {Name: "Description", Description: "What the middleware does", Required: false, Default: "processes requests"}, + }, + Output: "{{.Name | lower}}.go", + } + + cg.Templates["go-crud"] = &CodeTemplate{ + Name: "go-crud", + Description: "Full CRUD functions for a resource (Create, Get, List, Update, Delete)", + Language: "go", + Template: `package {{.Package}} + +import ( + "encoding/json" + "fmt" + "net/http" + "sync" +) + +// {{.Resource}} represents the {{.Resource}} entity. +type {{.Resource}} struct { + ID string ` + "`json:\"id\"`" + ` + Name string ` + "`json:\"name\"`" + ` + // TODO: add fields +} + +// {{.Resource}}Store manages {{.Resource}} persistence. +type {{.Resource}}Store struct { + mu sync.RWMutex + items map[string]*{{.Resource}} +} + +// New{{.Resource}}Store creates a new store. +func New{{.Resource}}Store() *{{.Resource}}Store { + return &{{.Resource}}Store{items: make(map[string]*{{.Resource}})} +} + +// Create{{.Resource}} adds a new {{.Resource}}. +func (s *{{.Resource}}Store) Create{{.Resource}}(item *{{.Resource}}) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, exists := s.items[item.ID]; exists { + return fmt.Errorf("{{.Resource}} with ID %s already exists", item.ID) + } + s.items[item.ID] = item + return nil +} + +// Get{{.Resource}} retrieves a {{.Resource}} by ID. +func (s *{{.Resource}}Store) Get{{.Resource}}(id string) (*{{.Resource}}, error) { + s.mu.RLock() + defer s.mu.RUnlock() + item, ok := s.items[id] + if !ok { + return nil, fmt.Errorf("{{.Resource}} with ID %s not found", id) + } + return item, nil +} + +// List{{.Resource}}s returns all {{.Resource}} items. +func (s *{{.Resource}}Store) List{{.Resource}}s() []*{{.Resource}} { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]*{{.Resource}}, 0, len(s.items)) + for _, item := range s.items { + result = append(result, item) + } + return result +} + +// Update{{.Resource}} updates an existing {{.Resource}}. +func (s *{{.Resource}}Store) Update{{.Resource}}(item *{{.Resource}}) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, exists := s.items[item.ID]; !exists { + return fmt.Errorf("{{.Resource}} with ID %s not found", item.ID) + } + s.items[item.ID] = item + return nil +} + +// Delete{{.Resource}} removes a {{.Resource}} by ID. +func (s *{{.Resource}}Store) Delete{{.Resource}}(id string) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, exists := s.items[id]; !exists { + return fmt.Errorf("{{.Resource}} with ID %s not found", id) + } + delete(s.items, id) + return nil +} + +// Handle{{.Resource}}s returns an HTTP handler for {{.Resource}} CRUD operations. +func (s *{{.Resource}}Store) Handle{{.Resource}}s(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.Method { + case http.MethodGet: + items := s.List{{.Resource}}s() + _ = json.NewEncoder(w).Encode(items) + case http.MethodPost: + var item {{.Resource}} + if err := json.NewDecoder(r.Body).Decode(&item); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.Create{{.Resource}}(&item); err != nil { + http.Error(w, err.Error(), http.StatusConflict) + return + } + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(item) + case http.MethodPut: + var item {{.Resource}} + if err := json.NewDecoder(r.Body).Decode(&item); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.Update{{.Resource}}(&item); err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + _ = json.NewEncoder(w).Encode(item) + case http.MethodDelete: + id := r.URL.Query().Get("id") + if err := s.Delete{{.Resource}}(id); err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + w.WriteHeader(http.StatusNoContent) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} +`, + Variables: []TemplateVar{ + {Name: "Package", Description: "Go package name", Required: true, Default: "models"}, + {Name: "Resource", Description: "Resource name (PascalCase)", Required: true, Default: ""}, + }, + Output: "{{.Resource | lower}}_crud.go", + } + + cg.Templates["go-test-table"] = &CodeTemplate{ + Name: "go-test-table", + Description: "Table-driven test with subtests", + Language: "go", + Template: `package {{.Package}} + +import ( + "testing" +) + +func Test{{.Function}}(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + { + name: "valid input", + input: "hello", + want: "expected", + wantErr: false, + }, + { + name: "empty input", + input: "", + want: "", + wantErr: true, + }, + // TODO: add more test cases + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := {{.Function}}(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("{{.Function}}() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("{{.Function}}() = %v, want %v", got, tt.want) + } + }) + } +} +`, + Variables: []TemplateVar{ + {Name: "Package", Description: "Go package name", Required: true, Default: ""}, + {Name: "Function", Description: "Function name to test", Required: true, Default: ""}, + }, + Output: "{{.Function | lower}}_test.go", + } + + cg.Templates["go-interface"] = &CodeTemplate{ + Name: "go-interface", + Description: "Interface definition with mock implementation", + Language: "go", + Template: `package {{.Package}} + +import ( + "context" + "sync" +) + +// {{.Name}} defines the interface for {{.Description}}. +type {{.Name}} interface { + Get(ctx context.Context, id string) (interface{}, error) + List(ctx context.Context) ([]interface{}, error) + Create(ctx context.Context, item interface{}) error + Update(ctx context.Context, id string, item interface{}) error + Delete(ctx context.Context, id string) error +} + +// Mock{{.Name}} is a test double for {{.Name}}. +type Mock{{.Name}} struct { + mu sync.Mutex + GetFunc func(ctx context.Context, id string) (interface{}, error) + ListFunc func(ctx context.Context) ([]interface{}, error) + CreateFunc func(ctx context.Context, item interface{}) error + UpdateFunc func(ctx context.Context, id string, item interface{}) error + DeleteFunc func(ctx context.Context, id string) error + Calls []string +} + +func (m *Mock{{.Name}}) Get(ctx context.Context, id string) (interface{}, error) { + m.mu.Lock() + m.Calls = append(m.Calls, "Get") + m.mu.Unlock() + if m.GetFunc != nil { + return m.GetFunc(ctx, id) + } + return nil, nil +} + +func (m *Mock{{.Name}}) List(ctx context.Context) ([]interface{}, error) { + m.mu.Lock() + m.Calls = append(m.Calls, "List") + m.mu.Unlock() + if m.ListFunc != nil { + return m.ListFunc(ctx) + } + return nil, nil +} + +func (m *Mock{{.Name}}) Create(ctx context.Context, item interface{}) error { + m.mu.Lock() + m.Calls = append(m.Calls, "Create") + m.mu.Unlock() + if m.CreateFunc != nil { + return m.CreateFunc(ctx, item) + } + return nil +} + +func (m *Mock{{.Name}}) Update(ctx context.Context, id string, item interface{}) error { + m.mu.Lock() + m.Calls = append(m.Calls, "Update") + m.mu.Unlock() + if m.UpdateFunc != nil { + return m.UpdateFunc(ctx, id, item) + } + return nil +} + +func (m *Mock{{.Name}}) Delete(ctx context.Context, id string) error { + m.mu.Lock() + m.Calls = append(m.Calls, "Delete") + m.mu.Unlock() + if m.DeleteFunc != nil { + return m.DeleteFunc(ctx, id) + } + return nil +} +`, + Variables: []TemplateVar{ + {Name: "Package", Description: "Go package name", Required: true, Default: ""}, + {Name: "Name", Description: "Interface name (PascalCase)", Required: true, Default: ""}, + {Name: "Description", Description: "What the interface represents", Required: false, Default: "a service"}, + }, + Output: "{{.Name | lower}}.go", + } + + cg.Templates["go-errors"] = &CodeTemplate{ + Name: "go-errors", + Description: "Custom error type with constructors", + Language: "go", + Template: `package {{.Package}} + +import "fmt" + +// {{.Name}}Error represents an error in the {{.Domain}} domain. +type {{.Name}}Error struct { + Code string + Message string + Err error +} + +func (e *{{.Name}}Error) Error() string { + if e.Err != nil { + return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Err) + } + return fmt.Sprintf("[%s] %s", e.Code, e.Message) +} + +func (e *{{.Name}}Error) Unwrap() error { + return e.Err +} + +// New{{.Name}}Error creates a new {{.Name}}Error. +func New{{.Name}}Error(code, message string) *{{.Name}}Error { + return &{{.Name}}Error{Code: code, Message: message} +} + +// Wrap{{.Name}}Error wraps an existing error with a {{.Name}}Error. +func Wrap{{.Name}}Error(code, message string, err error) *{{.Name}}Error { + return &{{.Name}}Error{Code: code, Message: message, Err: err} +} + +// ErrNotFound creates a not-found error. +func ErrNotFound(resource, id string) *{{.Name}}Error { + return New{{.Name}}Error("NOT_FOUND", fmt.Sprintf("%s with ID %s not found", resource, id)) +} + +// ErrValidation creates a validation error. +func ErrValidation(field, reason string) *{{.Name}}Error { + return New{{.Name}}Error("VALIDATION", fmt.Sprintf("field %s: %s", field, reason)) +} + +// ErrInternal creates an internal error wrapping the cause. +func ErrInternal(message string, err error) *{{.Name}}Error { + return Wrap{{.Name}}Error("INTERNAL", message, err) +} + +// Is{{.Name}}Error checks if an error is a {{.Name}}Error with the given code. +func Is{{.Name}}Error(err error, code string) bool { + if e, ok := err.(*{{.Name}}Error); ok { + return e.Code == code + } + return false +} +`, + Variables: []TemplateVar{ + {Name: "Package", Description: "Go package name", Required: true, Default: ""}, + {Name: "Name", Description: "Error type prefix (PascalCase)", Required: true, Default: ""}, + {Name: "Domain", Description: "Domain this error belongs to", Required: false, Default: "application"}, + }, + Output: "{{.Name | lower}}_errors.go", + } + + cg.Templates["go-config"] = &CodeTemplate{ + Name: "go-config", + Description: "Config struct with environment variable loading and validation", + Language: "go", + Template: `package {{.Package}} + +import ( + "fmt" + "os" + "strconv" + "strings" +) + +// {{.Name}}Config holds configuration for {{.Description}}. +type {{.Name}}Config struct { + Host string + Port int + Debug bool + LogLevel string + // TODO: add more config fields +} + +// Default{{.Name}}Config returns the default configuration. +func Default{{.Name}}Config() *{{.Name}}Config { + return &{{.Name}}Config{ + Host: "localhost", + Port: 8080, + Debug: false, + LogLevel: "info", + } +} + +// Load{{.Name}}Config loads configuration from environment variables. +// Variables are prefixed with {{.Prefix}}_. +func Load{{.Name}}Config() (*{{.Name}}Config, error) { + cfg := Default{{.Name}}Config() + + if v := os.Getenv("{{.Prefix}}_HOST"); v != "" { + cfg.Host = v + } + if v := os.Getenv("{{.Prefix}}_PORT"); v != "" { + port, err := strconv.Atoi(v) + if err != nil { + return nil, fmt.Errorf("invalid {{.Prefix}}_PORT: %w", err) + } + cfg.Port = port + } + if v := os.Getenv("{{.Prefix}}_DEBUG"); v != "" { + cfg.Debug = strings.ToLower(v) == "true" || v == "1" + } + if v := os.Getenv("{{.Prefix}}_LOG_LEVEL"); v != "" { + cfg.LogLevel = v + } + + if err := cfg.Validate(); err != nil { + return nil, err + } + return cfg, nil +} + +// Validate checks that the configuration is valid. +func (c *{{.Name}}Config) Validate() error { + if c.Host == "" { + return fmt.Errorf("host must not be empty") + } + if c.Port < 1 || c.Port > 65535 { + return fmt.Errorf("port must be between 1 and 65535, got %d", c.Port) + } + validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true} + if !validLevels[strings.ToLower(c.LogLevel)] { + return fmt.Errorf("invalid log level: %s (must be debug, info, warn, or error)", c.LogLevel) + } + return nil +} + +// Address returns the host:port address string. +func (c *{{.Name}}Config) Address() string { + return fmt.Sprintf("%s:%d", c.Host, c.Port) +} +`, + Variables: []TemplateVar{ + {Name: "Package", Description: "Go package name", Required: true, Default: "config"}, + {Name: "Name", Description: "Config name (PascalCase)", Required: true, Default: ""}, + {Name: "Prefix", Description: "Environment variable prefix (UPPER_CASE)", Required: true, Default: "APP"}, + {Name: "Description", Description: "What this config is for", Required: false, Default: "the application"}, + }, + Output: "{{.Name | lower}}_config.go", + } + + // Python templates + cg.Templates["py-fastapi-endpoint"] = &CodeTemplate{ + Name: "py-fastapi-endpoint", + Description: "FastAPI route with Pydantic model", + Language: "python", + Template: `from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field +from typing import Optional + +router = APIRouter(prefix="/{{.Prefix}}", tags=["{{.Tag}}"]) + + +class {{.Name}}Request(BaseModel): + """Request model for {{.Name}}.""" + name: str = Field(..., description="Name field") + # TODO: add request fields + + +class {{.Name}}Response(BaseModel): + """Response model for {{.Name}}.""" + id: str + name: str + # TODO: add response fields + + +@router.post("/", response_model={{.Name}}Response, status_code=201) +async def create_{{.NameLower}}(request: {{.Name}}Request) -> {{.Name}}Response: + """Create a new {{.Name}}.""" + # TODO: implement creation logic + return {{.Name}}Response(id="generated-id", name=request.name) + + +@router.get("/{item_id}", response_model={{.Name}}Response) +async def get_{{.NameLower}}(item_id: str) -> {{.Name}}Response: + """Get a {{.Name}} by ID.""" + # TODO: implement retrieval logic + raise HTTPException(status_code=404, detail="{{.Name}} not found") + + +@router.put("/{item_id}", response_model={{.Name}}Response) +async def update_{{.NameLower}}(item_id: str, request: {{.Name}}Request) -> {{.Name}}Response: + """Update a {{.Name}}.""" + # TODO: implement update logic + return {{.Name}}Response(id=item_id, name=request.name) + + +@router.delete("/{item_id}", status_code=204) +async def delete_{{.NameLower}}(item_id: str) -> None: + """Delete a {{.Name}}.""" + # TODO: implement deletion logic + pass +`, + Variables: []TemplateVar{ + {Name: "Name", Description: "Resource name (PascalCase)", Required: true, Default: ""}, + {Name: "NameLower", Description: "Resource name (lowercase)", Required: true, Default: ""}, + {Name: "Prefix", Description: "URL prefix", Required: false, Default: "api"}, + {Name: "Tag", Description: "OpenAPI tag", Required: false, Default: "default"}, + }, + Output: "{{.NameLower}}_router.py", + } + + cg.Templates["py-test-class"] = &CodeTemplate{ + Name: "py-test-class", + Description: "Pytest test class with setup/teardown", + Language: "python", + Template: `import pytest + + +class Test{{.Name}}: + """Tests for {{.Name}}.""" + + def setup_method(self): + """Set up test fixtures.""" + # TODO: initialize test fixtures + self.subject = None + + def teardown_method(self): + """Clean up after tests.""" + # TODO: clean up resources + pass + + def test_{{.MethodUnderTest}}_with_valid_input(self): + """Test {{.MethodUnderTest}} with valid input.""" + # Arrange + expected = None # TODO: set expected value + + # Act + result = self.subject.{{.MethodUnderTest}}() + + # Assert + assert result == expected + + def test_{{.MethodUnderTest}}_with_invalid_input(self): + """Test {{.MethodUnderTest}} raises on invalid input.""" + with pytest.raises(ValueError): + self.subject.{{.MethodUnderTest}}(None) + + def test_{{.MethodUnderTest}}_edge_case(self): + """Test {{.MethodUnderTest}} handles edge cases.""" + # TODO: implement edge case test + pass +`, + Variables: []TemplateVar{ + {Name: "Name", Description: "Class under test (PascalCase)", Required: true, Default: ""}, + {Name: "MethodUnderTest", Description: "Primary method to test", Required: true, Default: "execute"}, + }, + Output: "test_{{.Name | lower}}.py", + } + + cg.Templates["py-dataclass"] = &CodeTemplate{ + Name: "py-dataclass", + Description: "Dataclass with validation", + Language: "python", + Template: `from dataclasses import dataclass, field +from typing import Optional, List + + +@dataclass +class {{.Name}}: + """{{.Description}}""" + + name: str + value: int = 0 + tags: List[str] = field(default_factory=list) + metadata: Optional[str] = None + + def __post_init__(self): + """Validate fields after initialization.""" + if not self.name: + raise ValueError("name must not be empty") + if self.value < 0: + raise ValueError("value must be non-negative") + # TODO: add more validation + + def to_dict(self) -> dict: + """Convert to dictionary.""" + return { + "name": self.name, + "value": self.value, + "tags": list(self.tags), + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: dict) -> "{{.Name}}": + """Create instance from dictionary.""" + return cls( + name=data["name"], + value=data.get("value", 0), + tags=data.get("tags", []), + metadata=data.get("metadata"), + ) +`, + Variables: []TemplateVar{ + {Name: "Name", Description: "Class name (PascalCase)", Required: true, Default: ""}, + {Name: "Description", Description: "Class description", Required: false, Default: "A data model"}, + }, + Output: "{{.Name | lower}}.py", + } + + // TypeScript templates + cg.Templates["ts-react-component"] = &CodeTemplate{ + Name: "ts-react-component", + Description: "Functional React component with props interface", + Language: "typescript", + Template: `import React from 'react'; + +interface {{.Name}}Props { + title: string; + className?: string; + children?: React.ReactNode; + onClick?: () => void; +} + +/** + * {{.Description}} + */ +export const {{.Name}}: React.FC<{{.Name}}Props> = ({ + title, + className = '', + children, + onClick, +}) => { + return ( +
+

{title}

+ {children} +
+ ); +}; + +export default {{.Name}}; +`, + Variables: []TemplateVar{ + {Name: "Name", Description: "Component name (PascalCase)", Required: true, Default: ""}, + {Name: "Description", Description: "Component description", Required: false, Default: "A React component"}, + }, + Output: "{{.Name}}.tsx", + } + + cg.Templates["ts-express-router"] = &CodeTemplate{ + Name: "ts-express-router", + Description: "Express router with middleware", + Language: "typescript", + Template: `import { Router, Request, Response, NextFunction } from 'express'; + +const router = Router(); + +// Middleware for this router +function validate{{.Name}}(req: Request, res: Response, next: NextFunction): void { + // TODO: implement validation + next(); +} + +// GET /{{.Path}} +router.get('/', async (req: Request, res: Response) => { + try { + // TODO: implement list + res.json({ items: [] }); + } catch (error) { + res.status(500).json({ error: 'Internal server error' }); + } +}); + +// GET /{{.Path}}/:id +router.get('/:id', async (req: Request, res: Response) => { + try { + const { id } = req.params; + // TODO: implement get by id + res.json({ id }); + } catch (error) { + res.status(500).json({ error: 'Internal server error' }); + } +}); + +// POST /{{.Path}} +router.post('/', validate{{.Name}}, async (req: Request, res: Response) => { + try { + // TODO: implement create + res.status(201).json({ id: 'new-id', ...req.body }); + } catch (error) { + res.status(500).json({ error: 'Internal server error' }); + } +}); + +// PUT /{{.Path}}/:id +router.put('/:id', validate{{.Name}}, async (req: Request, res: Response) => { + try { + const { id } = req.params; + // TODO: implement update + res.json({ id, ...req.body }); + } catch (error) { + res.status(500).json({ error: 'Internal server error' }); + } +}); + +// DELETE /{{.Path}}/:id +router.delete('/:id', async (req: Request, res: Response) => { + try { + const { id } = req.params; + // TODO: implement delete + res.status(204).send(); + } catch (error) { + res.status(500).json({ error: 'Internal server error' }); + } +}); + +export default router; +`, + Variables: []TemplateVar{ + {Name: "Name", Description: "Resource name (PascalCase)", Required: true, Default: ""}, + {Name: "Path", Description: "Route path", Required: false, Default: "resources"}, + }, + Output: "{{.Name | lower}}.router.ts", + } + + cg.Templates["ts-test-describe"] = &CodeTemplate{ + Name: "ts-test-describe", + Description: "Jest/Vitest describe block with test cases", + Language: "typescript", + Template: `import { describe, it, expect, beforeEach, afterEach } from 'vitest'; + +describe('{{.Name}}', () => { + let subject: any; + + beforeEach(() => { + // TODO: set up test fixtures + subject = null; + }); + + afterEach(() => { + // TODO: clean up + }); + + describe('{{.Method}}', () => { + it('should handle valid input', () => { + // Arrange + const input = {}; + + // Act + const result = subject.{{.Method}}(input); + + // Assert + expect(result).toBeDefined(); + }); + + it('should throw on invalid input', () => { + expect(() => subject.{{.Method}}(null)).toThrow(); + }); + + it('should handle edge cases', () => { + // TODO: implement edge case test + expect(true).toBe(true); + }); + }); +}); +`, + Variables: []TemplateVar{ + {Name: "Name", Description: "Module/class under test", Required: true, Default: ""}, + {Name: "Method", Description: "Method being tested", Required: true, Default: "execute"}, + }, + Output: "{{.Name | lower}}.test.ts", + } +} From e46e0ee0288ac9ce35e0065ea53a774073ff2633 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:14:01 +0530 Subject: [PATCH 09/20] refactor(eval): split Go benchmark tasks 9-15 into tasks_go_more.go --- internal/feature/eval/tasks_go.go | 666 ------------------------- internal/feature/eval/tasks_go_more.go | 626 +++++++++++++++++++++++ 2 files changed, 626 insertions(+), 666 deletions(-) create mode 100644 internal/feature/eval/tasks_go_more.go diff --git a/internal/feature/eval/tasks_go.go b/internal/feature/eval/tasks_go.go index 5c1ec2b0..53786548 100644 --- a/internal/feature/eval/tasks_go.go +++ b/internal/feature/eval/tasks_go.go @@ -664,669 +664,3 @@ func TestParsePersonInvalid(t *testing.T) { ValidateFn: helperValidateBuildAndTest, } } - -// Task 9: Add context cancellation -func taskAddContextCancellation() BenchmarkTask { - return BenchmarkTask{ - ID: "go-add-context-cancellation", - Description: "Add context cancellation support to a long-running operation", - Prompt: "Add context.Context support to the Process function so it can be cancelled. The function should check for context cancellation between iterations and return ctx.Err() if cancelled.", - TimeLimit: 2 * time.Minute, - Tags: []string{"go", "context", "cancellation"}, - SetupFn: func(workDir string) error { - if err := helperInitModule(workDir, "ctxcancel"); err != nil { - return err - } - code := `package main - -import "context" - -// Process performs work in a loop, respecting context cancellation. -func Process(ctx context.Context, items []string) ([]string, error) { - results := make([]string, 0, len(items)) - for _, item := range items { - select { - case <-ctx.Done(): - return results, ctx.Err() - default: - } - results = append(results, "processed: "+item) - } - return results, nil -} -` - if err := helperWriteFile(workDir, "main.go", code); err != nil { - return err - } - - test := `package main - -import ( - "context" - "testing" -) - -func TestProcessSuccess(t *testing.T) { - ctx := context.Background() - items := []string{"a", "b", "c"} - results, err := Process(ctx, items) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(results) != 3 { - t.Fatalf("expected 3 results, got %d", len(results)) - } - if results[0] != "processed: a" { - t.Errorf("results[0] = %q, want %q", results[0], "processed: a") - } -} - -func TestProcessCancellation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() // cancel immediately - - items := []string{"a", "b", "c"} - _, err := Process(ctx, items) - if err == nil { - t.Error("expected error from cancelled context") - } - if err != context.Canceled { - t.Errorf("expected context.Canceled, got %v", err) - } -} -` - return helperWriteFile(workDir, "main_test.go", test) - }, - ValidateFn: helperValidateBuildAndTest, - } -} - -// Task 10: Fix an off-by-one error -func taskFixOffByOne() BenchmarkTask { - return BenchmarkTask{ - ID: "go-fix-off-by-one", - Description: "Fix an off-by-one error in a pagination function", - Prompt: "Fix the off-by-one error in the Paginate function. It should return the correct slice of items for the given page number (1-based) and page size.", - TimeLimit: 2 * time.Minute, - Tags: []string{"go", "bug-fix", "off-by-one"}, - SetupFn: func(workDir string) error { - if err := helperInitModule(workDir, "paginate"); err != nil { - return err - } - code := `package main - -// Paginate returns a page of items. Page is 1-based. -func Paginate(items []int, page, pageSize int) []int { - if page < 1 || pageSize < 1 { - return nil - } - start := (page - 1) * pageSize - if start >= len(items) { - return nil - } - end := start + pageSize - if end > len(items) { - end = len(items) - } - return items[start:end] -} -` - if err := helperWriteFile(workDir, "main.go", code); err != nil { - return err - } - - test := `package main - -import "testing" - -func TestPaginate(t *testing.T) { - items := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - - tests := []struct { - page, size int - want []int - }{ - {1, 3, []int{1, 2, 3}}, - {2, 3, []int{4, 5, 6}}, - {3, 3, []int{7, 8, 9}}, - {4, 3, []int{10}}, - {5, 3, nil}, - {1, 10, []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}}, - {0, 3, nil}, - {1, 0, nil}, - } - - for _, tt := range tests { - got := Paginate(items, tt.page, tt.size) - if !intSliceEqual(got, tt.want) { - t.Errorf("Paginate(items, %d, %d) = %v, want %v", tt.page, tt.size, got, tt.want) - } - } -} - -func intSliceEqual(a, b []int) bool { - if len(a) == 0 && len(b) == 0 { - return true - } - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} -` - return helperWriteFile(workDir, "main_test.go", test) - }, - ValidateFn: helperValidateBuildAndTest, - } -} - -// Task 11: Implement retry with backoff -func taskImplementRetryBackoff() BenchmarkTask { - return BenchmarkTask{ - ID: "go-implement-retry-backoff", - Description: "Implement a retry function with exponential backoff", - Prompt: "Implement the Retry function that retries a fallible operation with exponential backoff. It should retry up to maxRetries times, doubling the wait between each attempt starting from initialDelay.", - TimeLimit: 3 * time.Minute, - Tags: []string{"go", "retry", "backoff"}, - SetupFn: func(workDir string) error { - if err := helperInitModule(workDir, "retrybackoff"); err != nil { - return err - } - code := `package main - -import "time" - -// Retry retries fn up to maxRetries times with exponential backoff. -// initialDelay is doubled after each failed attempt. -func Retry(fn func() error, maxRetries int, initialDelay time.Duration) error { - var err error - delay := initialDelay - for i := 0; i <= maxRetries; i++ { - err = fn() - if err == nil { - return nil - } - if i < maxRetries { - time.Sleep(delay) - delay *= 2 - } - } - return err -} -` - if err := helperWriteFile(workDir, "main.go", code); err != nil { - return err - } - - test := `package main - -import ( - "errors" - "testing" - "time" -) - -func TestRetrySuccess(t *testing.T) { - calls := 0 - fn := func() error { - calls++ - if calls < 3 { - return errors.New("not yet") - } - return nil - } - - err := Retry(fn, 5, 1*time.Millisecond) - if err != nil { - t.Fatalf("expected nil error, got %v", err) - } - if calls != 3 { - t.Errorf("expected 3 calls, got %d", calls) - } -} - -func TestRetryExhausted(t *testing.T) { - fn := func() error { - return errors.New("always fails") - } - - err := Retry(fn, 3, 1*time.Millisecond) - if err == nil { - t.Error("expected error after retries exhausted") - } -} - -func TestRetryImmediateSuccess(t *testing.T) { - calls := 0 - fn := func() error { - calls++ - return nil - } - - err := Retry(fn, 3, 1*time.Millisecond) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if calls != 1 { - t.Errorf("expected 1 call, got %d", calls) - } -} -` - return helperWriteFile(workDir, "main_test.go", test) - }, - ValidateFn: helperValidateBuildAndTest, - } -} - -// Task 12: Convert callback to channel pattern -func taskCallbackToChannel() BenchmarkTask { - return BenchmarkTask{ - ID: "go-callback-to-channel", - Description: "Convert a callback-based API to use channels", - Prompt: "Implement the StreamResults function that converts the callback-based FetchWithCallback into a channel-based API. It should return a channel that receives results as they arrive.", - TimeLimit: 2 * time.Minute, - Tags: []string{"go", "channels", "concurrency"}, - SetupFn: func(workDir string) error { - if err := helperInitModule(workDir, "chanpattern"); err != nil { - return err - } - code := `package main - -// FetchWithCallback calls the callback for each result. -func FetchWithCallback(items []string, cb func(string)) { - for _, item := range items { - cb("result: " + item) - } -} - -// StreamResults converts callback-based FetchWithCallback into channel-based API. -func StreamResults(items []string) <-chan string { - ch := make(chan string) - go func() { - defer close(ch) - FetchWithCallback(items, func(result string) { - ch <- result - }) - }() - return ch -} -` - if err := helperWriteFile(workDir, "main.go", code); err != nil { - return err - } - - test := `package main - -import "testing" - -func TestStreamResults(t *testing.T) { - items := []string{"a", "b", "c"} - ch := StreamResults(items) - - var results []string - for r := range ch { - results = append(results, r) - } - - if len(results) != 3 { - t.Fatalf("expected 3 results, got %d", len(results)) - } - expected := []string{"result: a", "result: b", "result: c"} - for i, r := range results { - if r != expected[i] { - t.Errorf("results[%d] = %q, want %q", i, r, expected[i]) - } - } -} - -func TestStreamResultsEmpty(t *testing.T) { - ch := StreamResults(nil) - var results []string - for r := range ch { - results = append(results, r) - } - if len(results) != 0 { - t.Errorf("expected 0 results, got %d", len(results)) - } -} -` - return helperWriteFile(workDir, "main_test.go", test) - }, - ValidateFn: helperValidateBuildAndTest, - } -} - -// Task 13: Add input validation -func taskAddInputValidation() BenchmarkTask { - return BenchmarkTask{ - ID: "go-add-input-validation", - Description: "Add input validation to a user registration function", - Prompt: "Add input validation to the Register function. Validate that: name is non-empty (max 100 chars), email contains '@' and '.', age is between 0 and 150, password is at least 8 chars.", - TimeLimit: 2 * time.Minute, - Tags: []string{"go", "validation"}, - SetupFn: func(workDir string) error { - if err := helperInitModule(workDir, "validation"); err != nil { - return err - } - code := `package main - -import ( - "errors" - "strings" -) - -// Register validates and registers a user. -func Register(name, email string, age int, password string) error { - if name == "" || len(name) > 100 { - return errors.New("invalid name: must be 1-100 characters") - } - if !strings.Contains(email, "@") || !strings.Contains(email, ".") { - return errors.New("invalid email: must contain @ and .") - } - if age < 0 || age > 150 { - return errors.New("invalid age: must be 0-150") - } - if len(password) < 8 { - return errors.New("invalid password: must be at least 8 characters") - } - return nil -} -` - if err := helperWriteFile(workDir, "main.go", code); err != nil { - return err - } - - _ = `package main - -import "testing" - -func TestRegisterValid(t *testing.T) { - err := Register("Alice", "alice@example.com", 25, "securepass") - if err != nil { - t.Errorf("unexpected error: %v", err) - } -} - -func TestRegisterInvalidName(t *testing.T) { - if err := Register("", "a@b.c", 25, "securepass"); err == nil { - t.Error("expected error for empty name") - } - longName := strings.Repeat("a", 101) - if err := Register(longName, "a@b.c", 25, "securepass"); err == nil { - t.Error("expected error for name > 100 chars") - } -} - -func TestRegisterInvalidEmail(t *testing.T) { - if err := Register("Alice", "invalid", 25, "securepass"); err == nil { - t.Error("expected error for email without @") - } - if err := Register("Alice", "no@dot", 25, "securepass"); err == nil { - t.Error("expected error for email without .") - } -} - -func TestRegisterInvalidAge(t *testing.T) { - if err := Register("Alice", "a@b.c", -1, "securepass"); err == nil { - t.Error("expected error for negative age") - } - if err := Register("Alice", "a@b.c", 151, "securepass"); err == nil { - t.Error("expected error for age > 150") - } -} - -func TestRegisterInvalidPassword(t *testing.T) { - if err := Register("Alice", "a@b.c", 25, "short"); err == nil { - t.Error("expected error for password < 8 chars") - } -} -` - // Need to add strings import to test - test := `package main - -import ( - "strings" - "testing" -) - -func TestRegisterValid(t *testing.T) { - err := Register("Alice", "alice@example.com", 25, "securepass") - if err != nil { - t.Errorf("unexpected error: %v", err) - } -} - -func TestRegisterInvalidName(t *testing.T) { - if err := Register("", "a@b.c", 25, "securepass"); err == nil { - t.Error("expected error for empty name") - } - longName := strings.Repeat("a", 101) - if err := Register(longName, "a@b.c", 25, "securepass"); err == nil { - t.Error("expected error for name > 100 chars") - } -} - -func TestRegisterInvalidEmail(t *testing.T) { - if err := Register("Alice", "invalid", 25, "securepass"); err == nil { - t.Error("expected error for email without @") - } - if err := Register("Alice", "no@dot", 25, "securepass"); err == nil { - t.Error("expected error for email without .") - } -} - -func TestRegisterInvalidAge(t *testing.T) { - if err := Register("Alice", "a@b.c", -1, "securepass"); err == nil { - t.Error("expected error for negative age") - } - if err := Register("Alice", "a@b.c", 151, "securepass"); err == nil { - t.Error("expected error for age > 150") - } -} - -func TestRegisterInvalidPassword(t *testing.T) { - if err := Register("Alice", "a@b.c", 25, "short"); err == nil { - t.Error("expected error for password < 8 chars") - } -} -` - return helperWriteFile(workDir, "main_test.go", test) - }, - ValidateFn: helperValidateBuildAndTest, - } -} - -// Task 14: Fix a goroutine leak -func taskFixGoroutineLeak() BenchmarkTask { - return BenchmarkTask{ - ID: "go-fix-goroutine-leak", - Description: "Fix a goroutine leak in a producer function", - Prompt: "Fix the goroutine leak in the Produce function. The goroutine should stop when the done channel is closed, and the returned channel should be properly closed when the goroutine exits.", - TimeLimit: 2 * time.Minute, - Tags: []string{"go", "concurrency", "goroutine-leak"}, - SetupFn: func(workDir string) error { - if err := helperInitModule(workDir, "leakfix"); err != nil { - return err - } - code := `package main - -// Produce generates sequential numbers until done is closed. -func Produce(done <-chan struct{}) <-chan int { - ch := make(chan int) - go func() { - defer close(ch) - i := 0 - for { - select { - case <-done: - return - case ch <- i: - i++ - } - } - }() - return ch -} -` - if err := helperWriteFile(workDir, "main.go", code); err != nil { - return err - } - - test := `package main - -import "testing" - -func TestProduceAndStop(t *testing.T) { - done := make(chan struct{}) - ch := Produce(done) - - // Read a few values. - for i := 0; i < 5; i++ { - val := <-ch - if val != i { - t.Errorf("expected %d, got %d", i, val) - } - } - - // Signal done. - close(done) - - // Channel should eventually be closed. - // Drain any remaining buffered values. - drained := false - for range ch { - drained = true - _ = drained - } - // If we get here, the channel was closed properly. -} - -func TestProduceImmediateStop(t *testing.T) { - done := make(chan struct{}) - close(done) - ch := Produce(done) - - // Channel should be closed without producing values. - count := 0 - for range ch { - count++ - } - // Might get 0 or a very small number. - if count > 1 { - t.Errorf("expected at most 1 value after immediate close, got %d", count) - } -} -` - return helperWriteFile(workDir, "main_test.go", test) - }, - ValidateFn: helperValidateBuildAndTest, - } -} - -// Task 15: Implement a simple HTTP handler -func taskImplementHTTPHandler() BenchmarkTask { - return BenchmarkTask{ - ID: "go-implement-http-handler", - Description: "Implement a simple HTTP handler that returns JSON responses", - Prompt: "Implement the HealthHandler function that returns a JSON response with status 200 and body {\"status\":\"ok\",\"service\":\"hawk\"}. Also implement NotFoundHandler that returns 404 with {\"error\":\"not found\"}.", - TimeLimit: 2 * time.Minute, - Tags: []string{"go", "http", "handler"}, - SetupFn: func(workDir string) error { - if err := helperInitModule(workDir, "httphandler"); err != nil { - return err - } - code := `package main - -import ( - "encoding/json" - "net/http" -) - -// HealthHandler returns a JSON health check response. -func HealthHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ - "status": "ok", - "service": "hawk", - }) -} - -// NotFoundHandler returns a JSON 404 response. -func NotFoundHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]string{ - "error": "not found", - }) -} -` - if err := helperWriteFile(workDir, "main.go", code); err != nil { - return err - } - - test := `package main - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" -) - -func TestHealthHandler(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/health", nil) - w := httptest.NewRecorder() - - HealthHandler(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status = %d, want %d", w.Code, http.StatusOK) - } - - ct := w.Header().Get("Content-Type") - if ct != "application/json" { - t.Errorf("Content-Type = %q, want %q", ct, "application/json") - } - - var body map[string]string - if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to parse JSON: %v", err) - } - if body["status"] != "ok" { - t.Errorf("status = %q, want %q", body["status"], "ok") - } - if body["service"] != "hawk" { - t.Errorf("service = %q, want %q", body["service"], "hawk") - } -} - -func TestNotFoundHandler(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/missing", nil) - w := httptest.NewRecorder() - - NotFoundHandler(w, req) - - if w.Code != http.StatusNotFound { - t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound) - } - - var body map[string]string - if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to parse JSON: %v", err) - } - if body["error"] != "not found" { - t.Errorf("error = %q, want %q", body["error"], "not found") - } -} -` - return helperWriteFile(workDir, "main_test.go", test) - }, - ValidateFn: helperValidateBuildAndTest, - } -} diff --git a/internal/feature/eval/tasks_go_more.go b/internal/feature/eval/tasks_go_more.go new file mode 100644 index 00000000..4122d813 --- /dev/null +++ b/internal/feature/eval/tasks_go_more.go @@ -0,0 +1,626 @@ +package eval + +import "time" + +// This file holds Go benchmark tasks 9-15. Tasks 1-8, the GoTasks() suite +// builder, and the shared helpers live in tasks_go.go. + +// Task 9: Add context cancellation +func taskAddContextCancellation() BenchmarkTask { + return BenchmarkTask{ + ID: "go-add-context-cancellation", + Description: "Add context cancellation support to a long-running operation", + Prompt: "Add context.Context support to the Process function so it can be cancelled. The function should check for context cancellation between iterations and return ctx.Err() if cancelled.", + TimeLimit: 2 * time.Minute, + Tags: []string{"go", "context", "cancellation"}, + SetupFn: func(workDir string) error { + if err := helperInitModule(workDir, "ctxcancel"); err != nil { + return err + } + code := `package main + +import "context" + +// Process performs work in a loop, respecting context cancellation. +func Process(ctx context.Context, items []string) ([]string, error) { + results := make([]string, 0, len(items)) + for _, item := range items { + select { + case <-ctx.Done(): + return results, ctx.Err() + default: + } + results = append(results, "processed: "+item) + } + return results, nil +} +` + if err := helperWriteFile(workDir, "main.go", code); err != nil { + return err + } + + test := `package main + +import ( + "context" + "testing" +) + +func TestProcessSuccess(t *testing.T) { + ctx := context.Background() + items := []string{"a", "b", "c"} + results, err := Process(ctx, items) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) != 3 { + t.Fatalf("expected 3 results, got %d", len(results)) + } + if results[0] != "processed: a" { + t.Errorf("results[0] = %q, want %q", results[0], "processed: a") + } +} + +func TestProcessCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + items := []string{"a", "b", "c"} + _, err := Process(ctx, items) + if err == nil { + t.Error("expected error from cancelled context") + } + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %v", err) + } +} +` + return helperWriteFile(workDir, "main_test.go", test) + }, + ValidateFn: helperValidateBuildAndTest, + } +} + +// Task 10: Fix an off-by-one error +func taskFixOffByOne() BenchmarkTask { + return BenchmarkTask{ + ID: "go-fix-off-by-one", + Description: "Fix an off-by-one error in a pagination function", + Prompt: "Fix the off-by-one error in the Paginate function. It should return the correct slice of items for the given page number (1-based) and page size.", + TimeLimit: 2 * time.Minute, + Tags: []string{"go", "bug-fix", "off-by-one"}, + SetupFn: func(workDir string) error { + if err := helperInitModule(workDir, "paginate"); err != nil { + return err + } + code := `package main + +// Paginate returns a page of items. Page is 1-based. +func Paginate(items []int, page, pageSize int) []int { + if page < 1 || pageSize < 1 { + return nil + } + start := (page - 1) * pageSize + if start >= len(items) { + return nil + } + end := start + pageSize + if end > len(items) { + end = len(items) + } + return items[start:end] +} +` + if err := helperWriteFile(workDir, "main.go", code); err != nil { + return err + } + + test := `package main + +import "testing" + +func TestPaginate(t *testing.T) { + items := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + + tests := []struct { + page, size int + want []int + }{ + {1, 3, []int{1, 2, 3}}, + {2, 3, []int{4, 5, 6}}, + {3, 3, []int{7, 8, 9}}, + {4, 3, []int{10}}, + {5, 3, nil}, + {1, 10, []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}}, + {0, 3, nil}, + {1, 0, nil}, + } + + for _, tt := range tests { + got := Paginate(items, tt.page, tt.size) + if !intSliceEqual(got, tt.want) { + t.Errorf("Paginate(items, %d, %d) = %v, want %v", tt.page, tt.size, got, tt.want) + } + } +} + +func intSliceEqual(a, b []int) bool { + if len(a) == 0 && len(b) == 0 { + return true + } + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} +` + return helperWriteFile(workDir, "main_test.go", test) + }, + ValidateFn: helperValidateBuildAndTest, + } +} + +// Task 11: Implement retry with backoff +func taskImplementRetryBackoff() BenchmarkTask { + return BenchmarkTask{ + ID: "go-implement-retry-backoff", + Description: "Implement a retry function with exponential backoff", + Prompt: "Implement the Retry function that retries a fallible operation with exponential backoff. It should retry up to maxRetries times, doubling the wait between each attempt starting from initialDelay.", + TimeLimit: 3 * time.Minute, + Tags: []string{"go", "retry", "backoff"}, + SetupFn: func(workDir string) error { + if err := helperInitModule(workDir, "retrybackoff"); err != nil { + return err + } + code := `package main + +import "time" + +// Retry retries fn up to maxRetries times with exponential backoff. +// initialDelay is doubled after each failed attempt. +func Retry(fn func() error, maxRetries int, initialDelay time.Duration) error { + var err error + delay := initialDelay + for i := 0; i <= maxRetries; i++ { + err = fn() + if err == nil { + return nil + } + if i < maxRetries { + time.Sleep(delay) + delay *= 2 + } + } + return err +} +` + if err := helperWriteFile(workDir, "main.go", code); err != nil { + return err + } + + test := `package main + +import ( + "errors" + "testing" + "time" +) + +func TestRetrySuccess(t *testing.T) { + calls := 0 + fn := func() error { + calls++ + if calls < 3 { + return errors.New("not yet") + } + return nil + } + + err := Retry(fn, 5, 1*time.Millisecond) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if calls != 3 { + t.Errorf("expected 3 calls, got %d", calls) + } +} + +func TestRetryExhausted(t *testing.T) { + fn := func() error { + return errors.New("always fails") + } + + err := Retry(fn, 3, 1*time.Millisecond) + if err == nil { + t.Error("expected error after retries exhausted") + } +} + +func TestRetryImmediateSuccess(t *testing.T) { + calls := 0 + fn := func() error { + calls++ + return nil + } + + err := Retry(fn, 3, 1*time.Millisecond) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 1 { + t.Errorf("expected 1 call, got %d", calls) + } +} +` + return helperWriteFile(workDir, "main_test.go", test) + }, + ValidateFn: helperValidateBuildAndTest, + } +} + +// Task 12: Convert callback to channel pattern +func taskCallbackToChannel() BenchmarkTask { + return BenchmarkTask{ + ID: "go-callback-to-channel", + Description: "Convert a callback-based API to use channels", + Prompt: "Implement the StreamResults function that converts the callback-based FetchWithCallback into a channel-based API. It should return a channel that receives results as they arrive.", + TimeLimit: 2 * time.Minute, + Tags: []string{"go", "channels", "concurrency"}, + SetupFn: func(workDir string) error { + if err := helperInitModule(workDir, "chanpattern"); err != nil { + return err + } + code := `package main + +// FetchWithCallback calls the callback for each result. +func FetchWithCallback(items []string, cb func(string)) { + for _, item := range items { + cb("result: " + item) + } +} + +// StreamResults converts callback-based FetchWithCallback into channel-based API. +func StreamResults(items []string) <-chan string { + ch := make(chan string) + go func() { + defer close(ch) + FetchWithCallback(items, func(result string) { + ch <- result + }) + }() + return ch +} +` + if err := helperWriteFile(workDir, "main.go", code); err != nil { + return err + } + + test := `package main + +import "testing" + +func TestStreamResults(t *testing.T) { + items := []string{"a", "b", "c"} + ch := StreamResults(items) + + var results []string + for r := range ch { + results = append(results, r) + } + + if len(results) != 3 { + t.Fatalf("expected 3 results, got %d", len(results)) + } + expected := []string{"result: a", "result: b", "result: c"} + for i, r := range results { + if r != expected[i] { + t.Errorf("results[%d] = %q, want %q", i, r, expected[i]) + } + } +} + +func TestStreamResultsEmpty(t *testing.T) { + ch := StreamResults(nil) + var results []string + for r := range ch { + results = append(results, r) + } + if len(results) != 0 { + t.Errorf("expected 0 results, got %d", len(results)) + } +} +` + return helperWriteFile(workDir, "main_test.go", test) + }, + ValidateFn: helperValidateBuildAndTest, + } +} + +// Task 13: Add input validation +func taskAddInputValidation() BenchmarkTask { + return BenchmarkTask{ + ID: "go-add-input-validation", + Description: "Add input validation to a user registration function", + Prompt: "Add input validation to the Register function. Validate that: name is non-empty (max 100 chars), email contains '@' and '.', age is between 0 and 150, password is at least 8 chars.", + TimeLimit: 2 * time.Minute, + Tags: []string{"go", "validation"}, + SetupFn: func(workDir string) error { + if err := helperInitModule(workDir, "validation"); err != nil { + return err + } + code := `package main + +import ( + "errors" + "strings" +) + +// Register validates and registers a user. +func Register(name, email string, age int, password string) error { + if name == "" || len(name) > 100 { + return errors.New("invalid name: must be 1-100 characters") + } + if !strings.Contains(email, "@") || !strings.Contains(email, ".") { + return errors.New("invalid email: must contain @ and .") + } + if age < 0 || age > 150 { + return errors.New("invalid age: must be 0-150") + } + if len(password) < 8 { + return errors.New("invalid password: must be at least 8 characters") + } + return nil +} +` + if err := helperWriteFile(workDir, "main.go", code); err != nil { + return err + } + + test := `package main + +import ( + "strings" + "testing" +) + +func TestRegisterValid(t *testing.T) { + err := Register("Alice", "alice@example.com", 25, "securepass") + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestRegisterInvalidName(t *testing.T) { + if err := Register("", "a@b.c", 25, "securepass"); err == nil { + t.Error("expected error for empty name") + } + longName := strings.Repeat("a", 101) + if err := Register(longName, "a@b.c", 25, "securepass"); err == nil { + t.Error("expected error for name > 100 chars") + } +} + +func TestRegisterInvalidEmail(t *testing.T) { + if err := Register("Alice", "invalid", 25, "securepass"); err == nil { + t.Error("expected error for email without @") + } + if err := Register("Alice", "no@dot", 25, "securepass"); err == nil { + t.Error("expected error for email without .") + } +} + +func TestRegisterInvalidAge(t *testing.T) { + if err := Register("Alice", "a@b.c", -1, "securepass"); err == nil { + t.Error("expected error for negative age") + } + if err := Register("Alice", "a@b.c", 151, "securepass"); err == nil { + t.Error("expected error for age > 150") + } +} + +func TestRegisterInvalidPassword(t *testing.T) { + if err := Register("Alice", "a@b.c", 25, "short"); err == nil { + t.Error("expected error for password < 8 chars") + } +} +` + return helperWriteFile(workDir, "main_test.go", test) + }, + ValidateFn: helperValidateBuildAndTest, + } +} + +// Task 14: Fix a goroutine leak +func taskFixGoroutineLeak() BenchmarkTask { + return BenchmarkTask{ + ID: "go-fix-goroutine-leak", + Description: "Fix a goroutine leak in a producer function", + Prompt: "Fix the goroutine leak in the Produce function. The goroutine should stop when the done channel is closed, and the returned channel should be properly closed when the goroutine exits.", + TimeLimit: 2 * time.Minute, + Tags: []string{"go", "concurrency", "goroutine-leak"}, + SetupFn: func(workDir string) error { + if err := helperInitModule(workDir, "leakfix"); err != nil { + return err + } + code := `package main + +// Produce generates sequential numbers until done is closed. +func Produce(done <-chan struct{}) <-chan int { + ch := make(chan int) + go func() { + defer close(ch) + i := 0 + for { + select { + case <-done: + return + case ch <- i: + i++ + } + } + }() + return ch +} +` + if err := helperWriteFile(workDir, "main.go", code); err != nil { + return err + } + + test := `package main + +import "testing" + +func TestProduceAndStop(t *testing.T) { + done := make(chan struct{}) + ch := Produce(done) + + // Read a few values. + for i := 0; i < 5; i++ { + val := <-ch + if val != i { + t.Errorf("expected %d, got %d", i, val) + } + } + + // Signal done. + close(done) + + // Channel should eventually be closed. + // Drain any remaining buffered values. + drained := false + for range ch { + drained = true + _ = drained + } + // If we get here, the channel was closed properly. +} + +func TestProduceImmediateStop(t *testing.T) { + done := make(chan struct{}) + close(done) + ch := Produce(done) + + // Channel should be closed without producing values. + count := 0 + for range ch { + count++ + } + // Might get 0 or a very small number. + if count > 1 { + t.Errorf("expected at most 1 value after immediate close, got %d", count) + } +} +` + return helperWriteFile(workDir, "main_test.go", test) + }, + ValidateFn: helperValidateBuildAndTest, + } +} + +// Task 15: Implement a simple HTTP handler +func taskImplementHTTPHandler() BenchmarkTask { + return BenchmarkTask{ + ID: "go-implement-http-handler", + Description: "Implement a simple HTTP handler that returns JSON responses", + Prompt: "Implement the HealthHandler function that returns a JSON response with status 200 and body {\"status\":\"ok\",\"service\":\"hawk\"}. Also implement NotFoundHandler that returns 404 with {\"error\":\"not found\"}.", + TimeLimit: 2 * time.Minute, + Tags: []string{"go", "http", "handler"}, + SetupFn: func(workDir string) error { + if err := helperInitModule(workDir, "httphandler"); err != nil { + return err + } + code := `package main + +import ( + "encoding/json" + "net/http" +) + +// HealthHandler returns a JSON health check response. +func HealthHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "status": "ok", + "service": "hawk", + }) +} + +// NotFoundHandler returns a JSON 404 response. +func NotFoundHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]string{ + "error": "not found", + }) +} +` + if err := helperWriteFile(workDir, "main.go", code); err != nil { + return err + } + + test := `package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHealthHandler(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/health", nil) + w := httptest.NewRecorder() + + HealthHandler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + + ct := w.Header().Get("Content-Type") + if ct != "application/json" { + t.Errorf("Content-Type = %q, want %q", ct, "application/json") + } + + var body map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("failed to parse JSON: %v", err) + } + if body["status"] != "ok" { + t.Errorf("status = %q, want %q", body["status"], "ok") + } + if body["service"] != "hawk" { + t.Errorf("service = %q, want %q", body["service"], "hawk") + } +} + +func TestNotFoundHandler(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/missing", nil) + w := httptest.NewRecorder() + + NotFoundHandler(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound) + } + + var body map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("failed to parse JSON: %v", err) + } + if body["error"] != "not found" { + t.Errorf("error = %q, want %q", body["error"], "not found") + } +} +` + return helperWriteFile(workDir, "main_test.go", test) + }, + ValidateFn: helperValidateBuildAndTest, + } +} From ef27cd4b6854ea8338701d78c4b5d73ac9960f6f Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:17:16 +0530 Subject: [PATCH 10/20] refactor(cmd): split struct-based markdown renderer into markdown_renderer.go --- cmd/markdown.go | 761 +-------------------------------------- cmd/markdown_renderer.go | 759 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 763 insertions(+), 757 deletions(-) create mode 100644 cmd/markdown_renderer.go diff --git a/cmd/markdown.go b/cmd/markdown.go index 1764518a..20da863b 100644 --- a/cmd/markdown.go +++ b/cmd/markdown.go @@ -1,11 +1,8 @@ package cmd import ( - "fmt" "regexp" - "strconv" "strings" - "unicode" "github.com/charmbracelet/lipgloss" "github.com/mattn/go-runewidth" @@ -13,6 +10,10 @@ import ( // --------------------------------------------------------------------------- // Legacy markdown rendering using lipgloss (used by chat_view.go) +// +// The struct-based MarkdownRenderer (glamour/glow-inspired) lives in +// markdown_renderer.go. Shared helpers defined here (visibleWidth, stripAnsi, +// reAnsi, isHorizontalRule, parseHeader) are used by both renderers. // --------------------------------------------------------------------------- // Markdown rendering styles using the project's purpose-named palette. @@ -381,757 +382,3 @@ var reAnsi = regexp.MustCompile(`\x1b\[[0-9;]*[a-zA-Z]`) func stripAnsi(s string) string { return reAnsi.ReplaceAllString(s, "") } - -// --------------------------------------------------------------------------- -// Struct-based MarkdownRenderer (glamour/glow-inspired, stdlib-only ANSI) -// --------------------------------------------------------------------------- - -// MarkdownTheme defines ANSI escape codes for styling markdown elements. -type MarkdownTheme struct { - Heading string - Bold string - Italic string - Code string - CodeBlock string - Link string - ListBullet string - BlockQuote string - HorizontalRule string - Reset string -} - -// DefaultTheme returns a visually appealing terminal color theme. -func DefaultTheme() *MarkdownTheme { - return &MarkdownTheme{ - Heading: "\x1b[1;36m", // bold cyan - Bold: "\x1b[1m", // bold - Italic: "\x1b[3m", // italic - Code: "\x1b[48;5;236;37m", // dark bg + cyan fg - CodeBlock: "\x1b[48;5;236m", // dark background - Link: "\x1b[4;36m", // underline cyan - ListBullet: "\x1b[36m", // cyan - BlockQuote: "\x1b[3;90m", // italic dim - HorizontalRule: "\x1b[90m", // dim - Reset: "\x1b[0m", // reset all - } -} - -// MarkdownRenderer renders markdown text to styled ANSI terminal output. -type MarkdownRenderer struct { - Width int - Theme *MarkdownTheme - SyntaxHighlight bool -} - -// NewMarkdownRenderer creates a new renderer with the given terminal width. -func NewMarkdownRenderer(width int) *MarkdownRenderer { - if width <= 0 { - width = 80 - } - return &MarkdownRenderer{ - Width: width, - Theme: DefaultTheme(), - SyntaxHighlight: true, - } -} - -// Compiled regex patterns for the struct-based renderer. -var ( - reRendererBold = regexp.MustCompile(`\*\*(.+?)\*\*`) - reRendererItalic = regexp.MustCompile(`(?:^|[^*])\*([^*]+?)\*(?:[^*]|$)`) - reRendererCode = regexp.MustCompile("`([^`]+)`") - reRendererLink = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`) - reRendererOrderedLi = regexp.MustCompile(`^(\s*)(\d+)\.\s+(.*)$`) - reRendererTableRow = regexp.MustCompile(`^\|(.+)\|$`) - reRendererTableSep = regexp.MustCompile(`^\|[\s:]*[-]+[\s:]*`) - reHighlightKeyword = regexp.MustCompile(`\b(func|var|const|type|struct|interface|map|chan|go|defer|return|if|else|for|range|switch|case|default|break|continue|select|package|import|nil|true|false|def|class|self|from|import|as|with|yield|lambda|try|except|finally|raise|assert|pass|del|global|nonlocal|async|await|function|let|const|var|new|this|typeof|instanceof|export|import|from|async|await|fn|pub|mod|use|impl|trait|enum|match|loop|move|mut|ref|where|unsafe|extern|crate|macro|then|fi|do|done|elif|esac)\b`) - reHighlightString = regexp.MustCompile(`("(?:[^"\\]|\\.)*"|'(?:[^'\\]|\\.)*'|` + "`" + `[^` + "`" + `]*` + "`" + `)`) - reHighlightComment = regexp.MustCompile(`(//.*$|#.*$|/\*.*?\*/)`) - reHighlightNumber = regexp.MustCompile(`\b(\d+\.?\d*)\b`) -) - -// Render converts a markdown string to ANSI-styled terminal output. -func (r *MarkdownRenderer) Render(markdown string) string { - if markdown == "" { - return "" - } - - theme := r.Theme - if theme == nil { - theme = DefaultTheme() - } - width := r.Width - if width <= 0 { - width = 80 - } - - lines := strings.Split(markdown, "\n") - var result strings.Builder - i := 0 - - for i < len(lines) { - line := lines[i] - trimmed := strings.TrimSpace(line) - - // Fenced code block - if strings.HasPrefix(trimmed, "```") { - lang := strings.TrimSpace(strings.TrimPrefix(trimmed, "```")) - var codeLines []string - i++ - for i < len(lines) { - if strings.TrimSpace(lines[i]) == "```" { - i++ - break - } - codeLines = append(codeLines, lines[i]) - i++ - } - codeContent := strings.Join(codeLines, "\n") - result.WriteString(r.renderFencedCodeBlock(codeContent, lang, width)) - result.WriteByte('\n') - continue - } - - // Table detection - if reRendererTableRow.MatchString(trimmed) { - var tableLines []string - for i < len(lines) && reRendererTableRow.MatchString(strings.TrimSpace(lines[i])) { - tableLines = append(tableLines, strings.TrimSpace(lines[i])) - i++ - } - result.WriteString(r.renderTableFromLines(tableLines, width)) - result.WriteByte('\n') - continue - } - - // Horizontal rule - if isHorizontalRule(trimmed) { - result.WriteString(theme.HorizontalRule) - ruleWidth := width - if ruleWidth > 80 { - ruleWidth = 80 - } - result.WriteString(strings.Repeat("─", ruleWidth)) - result.WriteString(theme.Reset) - result.WriteByte('\n') - i++ - continue - } - - // Headers - if level, text := parseHeader(line); level > 0 { - rendered := r.renderInline(text) - if level == 1 { - result.WriteString(theme.Heading) - result.WriteString("\x1b[4m") // underline for h1 - result.WriteString(rendered) - result.WriteString(theme.Reset) - } else { - result.WriteString(theme.Heading) - result.WriteString(rendered) - result.WriteString(theme.Reset) - } - result.WriteByte('\n') - i++ - continue - } - - // Blockquote - if strings.HasPrefix(trimmed, "> ") || trimmed == ">" { - text := "" - if len(trimmed) > 2 { - text = trimmed[2:] - } - rendered := r.renderInline(text) - wrapped := WrapText(rendered, width-4) - for _, wl := range strings.Split(wrapped, "\n") { - result.WriteString(theme.BlockQuote) - result.WriteString("│ ") - result.WriteString(wl) - result.WriteString(theme.Reset) - result.WriteByte('\n') - } - i++ - continue - } - - // Unordered list - if bullet, text := r.parseListItem(line); bullet != "" { - indent := r.countLeadingSpaces(line) / 2 - indentStr := strings.Repeat(" ", indent) - rendered := r.renderInline(text) - wrapped := WrapText(rendered, width-len(indentStr)-4) - wrapLines := strings.Split(wrapped, "\n") - result.WriteString(indentStr) - result.WriteString(" ") - result.WriteString(theme.ListBullet) - result.WriteString(bullet) - result.WriteString(theme.Reset) - result.WriteString(" ") - result.WriteString(wrapLines[0]) - result.WriteByte('\n') - contIndent := indentStr + " " - for _, wl := range wrapLines[1:] { - result.WriteString(contIndent) - result.WriteString(wl) - result.WriteByte('\n') - } - i++ - continue - } - - // Ordered list - if m := reRendererOrderedLi.FindStringSubmatch(line); m != nil { - indentStr := m[1] - num := m[2] - text := m[3] - rendered := r.renderInline(text) - prefix := num + "." - wrapped := WrapText(rendered, width-len(indentStr)-len(prefix)-3) - wrapLines := strings.Split(wrapped, "\n") - result.WriteString(indentStr) - result.WriteString(" ") - result.WriteString(prefix) - result.WriteString(" ") - result.WriteString(wrapLines[0]) - result.WriteByte('\n') - contIndent := indentStr + strings.Repeat(" ", len(prefix)+3) - for _, wl := range wrapLines[1:] { - result.WriteString(contIndent) - result.WriteString(wl) - result.WriteByte('\n') - } - i++ - continue - } - - // Empty line - if trimmed == "" { - result.WriteByte('\n') - i++ - continue - } - - // Regular paragraph - rendered := r.renderInline(line) - wrapped := WrapText(rendered, width) - result.WriteString(wrapped) - result.WriteByte('\n') - i++ - } - - return strings.TrimRight(result.String(), "\n") -} - -// renderInline applies inline formatting (bold, italic, code, links). -func (r *MarkdownRenderer) renderInline(text string) string { - theme := r.Theme - - // Links - text = reRendererLink.ReplaceAllStringFunc(text, func(m string) string { - parts := reRendererLink.FindStringSubmatch(m) - if len(parts) < 3 { - return m - } - return theme.Link + parts[1] + theme.Reset + " (" + parts[2] + ")" - }) - - // Inline code (before bold/italic) - text = reRendererCode.ReplaceAllStringFunc(text, func(m string) string { - parts := reRendererCode.FindStringSubmatch(m) - if len(parts) < 2 { - return m - } - return theme.Code + parts[1] + theme.Reset - }) - - // Bold - text = reRendererBold.ReplaceAllStringFunc(text, func(m string) string { - parts := reRendererBold.FindStringSubmatch(m) - if len(parts) < 2 { - return m - } - return theme.Bold + parts[1] + theme.Reset - }) - - // Italic - text = reRendererItalic.ReplaceAllStringFunc(text, func(m string) string { - parts := reRendererItalic.FindStringSubmatch(m) - if len(parts) < 2 { - return m - } - prefix := "" - suffix := "" - if len(m) > 0 && m[0] != '*' { - prefix = string(m[0]) - } - if len(m) > 0 && m[len(m)-1] != '*' { - suffix = string(m[len(m)-1]) - } - return prefix + theme.Italic + parts[1] + theme.Reset + suffix - }) - - return text -} - -// parseListItem detects unordered list items with various bullet markers. -func (r *MarkdownRenderer) parseListItem(line string) (string, string) { - trimmed := strings.TrimLeft(line, " \t") - for _, prefix := range []string{"- ", "* ", "+ "} { - if strings.HasPrefix(trimmed, prefix) { - return "•", strings.TrimSpace(trimmed[2:]) - } - } - return "", "" -} - -// countLeadingSpaces returns the number of leading space characters. -func (r *MarkdownRenderer) countLeadingSpaces(line string) int { - count := 0 - for _, ch := range line { - if ch == ' ' { - count++ - } else if ch == '\t' { - count += 2 - } else { - break - } - } - return count -} - -// renderFencedCodeBlock renders a code block with optional syntax highlighting. -func (r *MarkdownRenderer) renderFencedCodeBlock(code, lang string, width int) string { - theme := r.Theme - var b strings.Builder - innerWidth := width - 6 - if innerWidth < 20 { - innerWidth = width - 2 - } - - // Language label - if lang != "" { - b.WriteString(" \x1b[90m") - b.WriteString(" " + lang + " ") - b.WriteString(theme.Reset) - b.WriteByte('\n') - } - - // Optionally syntax highlight - highlighted := code - if r.SyntaxHighlight && lang != "" { - highlighted = HighlightCode(code, lang) - } - - for _, line := range strings.Split(highlighted, "\n") { - b.WriteString(" ") - b.WriteString(theme.CodeBlock) - b.WriteString(" ") - // Pad to innerWidth for consistent background - plain := StripANSI(line) - visW := 0 - for _, r := range plain { - visW += runeWidth(r) - } - b.WriteString(line) - if visW < innerWidth { - b.WriteString(strings.Repeat(" ", innerWidth-visW)) - } - b.WriteString(" ") - b.WriteString(theme.Reset) - b.WriteByte('\n') - } - - return strings.TrimRight(b.String(), "\n") -} - -// renderTableFromLines parses markdown table lines and renders with box-drawing. -func (r *MarkdownRenderer) renderTableFromLines(tableLines []string, width int) string { - if len(tableLines) == 0 { - return "" - } - - // Parse rows - var rows [][]string - for _, line := range tableLines { - // Skip separator lines (e.g., |---|---|) - if reRendererTableSep.MatchString(line) { - continue - } - cells := parseTableRow(line) - if len(cells) > 0 { - rows = append(rows, cells) - } - } - - if len(rows) == 0 { - return "" - } - - return RenderTable(rows) -} - -// parseTableRow splits a table row like "|a|b|c|" into cells. -func parseTableRow(line string) []string { - // Remove leading/trailing | - line = strings.TrimSpace(line) - line = strings.TrimPrefix(line, "|") - line = strings.TrimSuffix(line, "|") - parts := strings.Split(line, "|") - cells := make([]string, len(parts)) - for i, p := range parts { - cells[i] = strings.TrimSpace(p) - } - return cells -} - -// HighlightCode performs regex-based syntax highlighting for common languages. -// Supports Go, Python, JavaScript/TypeScript, Rust, and Bash. -func HighlightCode(code string, language string) string { - lang := strings.ToLower(language) - - // Only highlight supported languages - switch lang { - case "go", "golang", "python", "py", "javascript", "js", "typescript", "ts", "rust", "rs", "bash", "sh", "shell", "zsh": - // proceed - default: - return code - } - - // ANSI color codes for syntax elements - const ( - keywordColor = "\x1b[38;5;198m" // magenta/pink for keywords - stringColor = "\x1b[38;5;113m" // green for strings - commentColor = "\x1b[38;5;242m" // gray for comments - numberColor = "\x1b[38;5;141m" // purple for numbers - resetColor = "\x1b[0m" - ) - - // Process line by line to handle comments correctly - lines := strings.Split(code, "\n") - var result []string - for _, line := range lines { - highlighted := line - - // Comments first (they override everything else on the line) - if loc := reHighlightComment.FindStringIndex(highlighted); loc != nil { - before := highlighted[:loc[0]] - comment := highlighted[loc[0]:loc[1]] - after := highlighted[loc[1]:] - before = highlightNonComment(before, keywordColor, stringColor, numberColor, resetColor) - highlighted = before + commentColor + comment + resetColor + after - } else { - highlighted = highlightNonComment(highlighted, keywordColor, stringColor, numberColor, resetColor) - } - - result = append(result, highlighted) - } - - return strings.Join(result, "\n") -} - -// highlightNonComment highlights keywords, strings, and numbers in non-comment text. -func highlightNonComment(text, keywordColor, stringColor, numberColor, resetColor string) string { - // Strings first (so keywords inside strings are not highlighted) - text = reHighlightString.ReplaceAllStringFunc(text, func(m string) string { - return stringColor + m + resetColor - }) - - // Keywords (only highlight if not inside a string - simplified approach) - text = reHighlightKeyword.ReplaceAllStringFunc(text, func(m string) string { - return keywordColor + m + resetColor - }) - - // Numbers - text = reHighlightNumber.ReplaceAllStringFunc(text, func(m string) string { - // Don't highlight numbers that are part of ANSI escape sequences - return numberColor + m + resetColor - }) - - return text -} - -// WrapText performs word-wrapping at the specified width boundary. -// It respects ANSI escape codes by measuring only visible characters. -func WrapText(text string, width int) string { - if width <= 0 { - width = 80 - } - if text == "" { - return "" - } - - // Quick check: if text already fits, return as-is - plainLen := len(StripANSI(text)) - if plainLen <= width { - return text - } - - var result strings.Builder - words := strings.Fields(text) - curWidth := 0 - - for _, word := range words { - wordW := visibleWidth(word) - if curWidth > 0 && curWidth+1+wordW > width { - result.WriteByte('\n') - result.WriteString(word) - curWidth = wordW - } else if curWidth > 0 { - result.WriteByte(' ') - result.WriteString(word) - curWidth += 1 + wordW - } else { - result.WriteString(word) - curWidth = wordW - } - } - return result.String() -} - -// RenderTable renders a table with box-drawing characters. -// The first row is treated as the header. Column widths are auto-calculated. -func RenderTable(rows [][]string) string { - if len(rows) == 0 { - return "" - } - - // Determine number of columns - numCols := 0 - for _, row := range rows { - if len(row) > numCols { - numCols = len(row) - } - } - if numCols == 0 { - return "" - } - - // Calculate column widths - colWidths := make([]int, numCols) - for _, row := range rows { - for i, cell := range row { - if i < numCols { - w := len(StripANSI(cell)) - if w > colWidths[i] { - colWidths[i] = w - } - } - } - } - - // Ensure minimum width of 3 - for i := range colWidths { - if colWidths[i] < 3 { - colWidths[i] = 3 - } - } - - var b strings.Builder - - // Top border: ┌───┬───┐ - b.WriteString("┌") - for i, w := range colWidths { - b.WriteString(strings.Repeat("─", w+2)) - if i < numCols-1 { - b.WriteString("┬") - } - } - b.WriteString("┐\n") - - for rowIdx, row := range rows { - // Row content: │ cell │ cell │ - b.WriteString("│") - for i := 0; i < numCols; i++ { - cell := "" - if i < len(row) { - cell = row[i] - } - plainCell := StripANSI(cell) - pad := colWidths[i] - len(plainCell) - if pad < 0 { - pad = 0 - } - b.WriteString(" ") - b.WriteString(cell) - b.WriteString(strings.Repeat(" ", pad)) - b.WriteString(" │") - } - b.WriteString("\n") - - // After header row: ├───┼───┤ - if rowIdx == 0 && len(rows) > 1 { - b.WriteString("├") - for i, w := range colWidths { - b.WriteString(strings.Repeat("─", w+2)) - if i < numCols-1 { - b.WriteString("┼") - } - } - b.WriteString("┤\n") - } - } - - // Bottom border: └───┴───┘ - b.WriteString("└") - for i, w := range colWidths { - b.WriteString(strings.Repeat("─", w+2)) - if i < numCols-1 { - b.WriteString("┴") - } - } - b.WriteString("┘") - - return b.String() -} - -// StripANSI removes all ANSI escape codes from a string (for plain output). -func StripANSI(text string) string { - return reAnsi.ReplaceAllString(text, "") -} - -// RenderStreaming takes a channel of raw markdown chunks and returns a channel -// of rendered chunks. It buffers partial markdown elements until they can be -// completely rendered. -func RenderStreaming(ch <-chan string) <-chan string { - out := make(chan string, 16) - - go func() { - defer close(out) - - renderer := NewMarkdownRenderer(80) - var buffer strings.Builder - var lastRendered string - - for chunk := range ch { - buffer.WriteString(chunk) - current := buffer.String() - - // Check if we have incomplete elements that need buffering - if hasIncompleteElement(current) { - // Try to render what we can - safe := findSafeRenderPoint(current) - if safe == "" { - continue // buffer more - } - rendered := renderer.Render(safe) - if rendered != lastRendered { - // Send only the new part - diff := computeStreamDiff(lastRendered, rendered) - if diff != "" { - out <- diff - } - lastRendered = rendered - } - } else { - rendered := renderer.Render(current) - if rendered != lastRendered { - diff := computeStreamDiff(lastRendered, rendered) - if diff != "" { - out <- diff - } - lastRendered = rendered - } - } - } - - // Final flush - final := renderer.Render(buffer.String()) - if final != lastRendered { - diff := computeStreamDiff(lastRendered, final) - if diff != "" { - out <- diff - } - } - }() - - return out -} - -// hasIncompleteElement checks for partial markdown elements that should be buffered. -func hasIncompleteElement(s string) bool { - // Unclosed bold - count := strings.Count(s, "**") - if count%2 != 0 { - return true - } - - // Unclosed inline code - inCode := false - for _, ch := range s { - if ch == '`' { - inCode = !inCode - } - } - if inCode { - return true - } - - // Unclosed fenced code block - fenceCount := 0 - for _, line := range strings.Split(s, "\n") { - if strings.HasPrefix(strings.TrimSpace(line), "```") { - fenceCount++ - } - } - return fenceCount%2 != 0 -} - -// findSafeRenderPoint finds the longest prefix that can be safely rendered. -func findSafeRenderPoint(s string) string { - // Try to find the last complete line - lastNewline := strings.LastIndex(s, "\n") - if lastNewline <= 0 { - return "" - } - - candidate := s[:lastNewline] - // Verify this candidate doesn't have incomplete elements - if !hasIncompleteElement(candidate) { - return candidate - } - - // Try second-to-last newline - prevNewline := strings.LastIndex(candidate, "\n") - if prevNewline > 0 { - candidate = s[:prevNewline] - if !hasIncompleteElement(candidate) { - return candidate - } - } - - return "" -} - -// computeStreamDiff computes what new content to emit given old and new rendered text. -func computeStreamDiff(old, new string) string { - if old == "" { - return new - } - if strings.HasPrefix(new, old) { - return new[len(old):] - } - // Content changed (re-rendering), send full new content with clear - return "\r\x1b[J" + new -} - -// runeWidth returns the display width of a single rune. -func runeWidth(r rune) int { - if r == '\t' { - return 4 - } - if !unicode.IsPrint(r) { - return 0 - } - // Use East Asian width awareness - if unicode.Is(unicode.Han, r) || unicode.Is(unicode.Hangul, r) || unicode.Is(unicode.Katakana, r) || unicode.Is(unicode.Hiragana, r) { - return 2 - } - return 1 -} - -// Ensure fmt and strconv are used (required imports for table rendering) -var ( - _ = fmt.Sprintf - _ = strconv.Itoa -) diff --git a/cmd/markdown_renderer.go b/cmd/markdown_renderer.go new file mode 100644 index 00000000..97e875bc --- /dev/null +++ b/cmd/markdown_renderer.go @@ -0,0 +1,759 @@ +package cmd + +import ( + "regexp" + "strings" + "unicode" +) + +// --------------------------------------------------------------------------- +// Struct-based MarkdownRenderer (glamour/glow-inspired, stdlib-only ANSI) +// +// The legacy lipgloss-based renderer and the shared helpers it defines +// (visibleWidth, stripAnsi, reAnsi, isHorizontalRule, parseHeader) live in +// markdown.go. +// --------------------------------------------------------------------------- + +// MarkdownTheme defines ANSI escape codes for styling markdown elements. +type MarkdownTheme struct { + Heading string + Bold string + Italic string + Code string + CodeBlock string + Link string + ListBullet string + BlockQuote string + HorizontalRule string + Reset string +} + +// DefaultTheme returns a visually appealing terminal color theme. +func DefaultTheme() *MarkdownTheme { + return &MarkdownTheme{ + Heading: "\x1b[1;36m", // bold cyan + Bold: "\x1b[1m", // bold + Italic: "\x1b[3m", // italic + Code: "\x1b[48;5;236;37m", // dark bg + cyan fg + CodeBlock: "\x1b[48;5;236m", // dark background + Link: "\x1b[4;36m", // underline cyan + ListBullet: "\x1b[36m", // cyan + BlockQuote: "\x1b[3;90m", // italic dim + HorizontalRule: "\x1b[90m", // dim + Reset: "\x1b[0m", // reset all + } +} + +// MarkdownRenderer renders markdown text to styled ANSI terminal output. +type MarkdownRenderer struct { + Width int + Theme *MarkdownTheme + SyntaxHighlight bool +} + +// NewMarkdownRenderer creates a new renderer with the given terminal width. +func NewMarkdownRenderer(width int) *MarkdownRenderer { + if width <= 0 { + width = 80 + } + return &MarkdownRenderer{ + Width: width, + Theme: DefaultTheme(), + SyntaxHighlight: true, + } +} + +// Compiled regex patterns for the struct-based renderer. +var ( + reRendererBold = regexp.MustCompile(`\*\*(.+?)\*\*`) + reRendererItalic = regexp.MustCompile(`(?:^|[^*])\*([^*]+?)\*(?:[^*]|$)`) + reRendererCode = regexp.MustCompile("`([^`]+)`") + reRendererLink = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`) + reRendererOrderedLi = regexp.MustCompile(`^(\s*)(\d+)\.\s+(.*)$`) + reRendererTableRow = regexp.MustCompile(`^\|(.+)\|$`) + reRendererTableSep = regexp.MustCompile(`^\|[\s:]*[-]+[\s:]*`) + reHighlightKeyword = regexp.MustCompile(`\b(func|var|const|type|struct|interface|map|chan|go|defer|return|if|else|for|range|switch|case|default|break|continue|select|package|import|nil|true|false|def|class|self|from|import|as|with|yield|lambda|try|except|finally|raise|assert|pass|del|global|nonlocal|async|await|function|let|const|var|new|this|typeof|instanceof|export|import|from|async|await|fn|pub|mod|use|impl|trait|enum|match|loop|move|mut|ref|where|unsafe|extern|crate|macro|then|fi|do|done|elif|esac)\b`) + reHighlightString = regexp.MustCompile(`("(?:[^"\\]|\\.)*"|'(?:[^'\\]|\\.)*'|` + "`" + `[^` + "`" + `]*` + "`" + `)`) + reHighlightComment = regexp.MustCompile(`(//.*$|#.*$|/\*.*?\*/)`) + reHighlightNumber = regexp.MustCompile(`\b(\d+\.?\d*)\b`) +) + +// Render converts a markdown string to ANSI-styled terminal output. +func (r *MarkdownRenderer) Render(markdown string) string { + if markdown == "" { + return "" + } + + theme := r.Theme + if theme == nil { + theme = DefaultTheme() + } + width := r.Width + if width <= 0 { + width = 80 + } + + lines := strings.Split(markdown, "\n") + var result strings.Builder + i := 0 + + for i < len(lines) { + line := lines[i] + trimmed := strings.TrimSpace(line) + + // Fenced code block + if strings.HasPrefix(trimmed, "```") { + lang := strings.TrimSpace(strings.TrimPrefix(trimmed, "```")) + var codeLines []string + i++ + for i < len(lines) { + if strings.TrimSpace(lines[i]) == "```" { + i++ + break + } + codeLines = append(codeLines, lines[i]) + i++ + } + codeContent := strings.Join(codeLines, "\n") + result.WriteString(r.renderFencedCodeBlock(codeContent, lang, width)) + result.WriteByte('\n') + continue + } + + // Table detection + if reRendererTableRow.MatchString(trimmed) { + var tableLines []string + for i < len(lines) && reRendererTableRow.MatchString(strings.TrimSpace(lines[i])) { + tableLines = append(tableLines, strings.TrimSpace(lines[i])) + i++ + } + result.WriteString(r.renderTableFromLines(tableLines, width)) + result.WriteByte('\n') + continue + } + + // Horizontal rule + if isHorizontalRule(trimmed) { + result.WriteString(theme.HorizontalRule) + ruleWidth := width + if ruleWidth > 80 { + ruleWidth = 80 + } + result.WriteString(strings.Repeat("─", ruleWidth)) + result.WriteString(theme.Reset) + result.WriteByte('\n') + i++ + continue + } + + // Headers + if level, text := parseHeader(line); level > 0 { + rendered := r.renderInline(text) + if level == 1 { + result.WriteString(theme.Heading) + result.WriteString("\x1b[4m") // underline for h1 + result.WriteString(rendered) + result.WriteString(theme.Reset) + } else { + result.WriteString(theme.Heading) + result.WriteString(rendered) + result.WriteString(theme.Reset) + } + result.WriteByte('\n') + i++ + continue + } + + // Blockquote + if strings.HasPrefix(trimmed, "> ") || trimmed == ">" { + text := "" + if len(trimmed) > 2 { + text = trimmed[2:] + } + rendered := r.renderInline(text) + wrapped := WrapText(rendered, width-4) + for _, wl := range strings.Split(wrapped, "\n") { + result.WriteString(theme.BlockQuote) + result.WriteString("│ ") + result.WriteString(wl) + result.WriteString(theme.Reset) + result.WriteByte('\n') + } + i++ + continue + } + + // Unordered list + if bullet, text := r.parseListItem(line); bullet != "" { + indent := r.countLeadingSpaces(line) / 2 + indentStr := strings.Repeat(" ", indent) + rendered := r.renderInline(text) + wrapped := WrapText(rendered, width-len(indentStr)-4) + wrapLines := strings.Split(wrapped, "\n") + result.WriteString(indentStr) + result.WriteString(" ") + result.WriteString(theme.ListBullet) + result.WriteString(bullet) + result.WriteString(theme.Reset) + result.WriteString(" ") + result.WriteString(wrapLines[0]) + result.WriteByte('\n') + contIndent := indentStr + " " + for _, wl := range wrapLines[1:] { + result.WriteString(contIndent) + result.WriteString(wl) + result.WriteByte('\n') + } + i++ + continue + } + + // Ordered list + if m := reRendererOrderedLi.FindStringSubmatch(line); m != nil { + indentStr := m[1] + num := m[2] + text := m[3] + rendered := r.renderInline(text) + prefix := num + "." + wrapped := WrapText(rendered, width-len(indentStr)-len(prefix)-3) + wrapLines := strings.Split(wrapped, "\n") + result.WriteString(indentStr) + result.WriteString(" ") + result.WriteString(prefix) + result.WriteString(" ") + result.WriteString(wrapLines[0]) + result.WriteByte('\n') + contIndent := indentStr + strings.Repeat(" ", len(prefix)+3) + for _, wl := range wrapLines[1:] { + result.WriteString(contIndent) + result.WriteString(wl) + result.WriteByte('\n') + } + i++ + continue + } + + // Empty line + if trimmed == "" { + result.WriteByte('\n') + i++ + continue + } + + // Regular paragraph + rendered := r.renderInline(line) + wrapped := WrapText(rendered, width) + result.WriteString(wrapped) + result.WriteByte('\n') + i++ + } + + return strings.TrimRight(result.String(), "\n") +} + +// renderInline applies inline formatting (bold, italic, code, links). +func (r *MarkdownRenderer) renderInline(text string) string { + theme := r.Theme + + // Links + text = reRendererLink.ReplaceAllStringFunc(text, func(m string) string { + parts := reRendererLink.FindStringSubmatch(m) + if len(parts) < 3 { + return m + } + return theme.Link + parts[1] + theme.Reset + " (" + parts[2] + ")" + }) + + // Inline code (before bold/italic) + text = reRendererCode.ReplaceAllStringFunc(text, func(m string) string { + parts := reRendererCode.FindStringSubmatch(m) + if len(parts) < 2 { + return m + } + return theme.Code + parts[1] + theme.Reset + }) + + // Bold + text = reRendererBold.ReplaceAllStringFunc(text, func(m string) string { + parts := reRendererBold.FindStringSubmatch(m) + if len(parts) < 2 { + return m + } + return theme.Bold + parts[1] + theme.Reset + }) + + // Italic + text = reRendererItalic.ReplaceAllStringFunc(text, func(m string) string { + parts := reRendererItalic.FindStringSubmatch(m) + if len(parts) < 2 { + return m + } + prefix := "" + suffix := "" + if len(m) > 0 && m[0] != '*' { + prefix = string(m[0]) + } + if len(m) > 0 && m[len(m)-1] != '*' { + suffix = string(m[len(m)-1]) + } + return prefix + theme.Italic + parts[1] + theme.Reset + suffix + }) + + return text +} + +// parseListItem detects unordered list items with various bullet markers. +func (r *MarkdownRenderer) parseListItem(line string) (string, string) { + trimmed := strings.TrimLeft(line, " \t") + for _, prefix := range []string{"- ", "* ", "+ "} { + if strings.HasPrefix(trimmed, prefix) { + return "•", strings.TrimSpace(trimmed[2:]) + } + } + return "", "" +} + +// countLeadingSpaces returns the number of leading space characters. +func (r *MarkdownRenderer) countLeadingSpaces(line string) int { + count := 0 + for _, ch := range line { + if ch == ' ' { + count++ + } else if ch == '\t' { + count += 2 + } else { + break + } + } + return count +} + +// renderFencedCodeBlock renders a code block with optional syntax highlighting. +func (r *MarkdownRenderer) renderFencedCodeBlock(code, lang string, width int) string { + theme := r.Theme + var b strings.Builder + innerWidth := width - 6 + if innerWidth < 20 { + innerWidth = width - 2 + } + + // Language label + if lang != "" { + b.WriteString(" \x1b[90m") + b.WriteString(" " + lang + " ") + b.WriteString(theme.Reset) + b.WriteByte('\n') + } + + // Optionally syntax highlight + highlighted := code + if r.SyntaxHighlight && lang != "" { + highlighted = HighlightCode(code, lang) + } + + for _, line := range strings.Split(highlighted, "\n") { + b.WriteString(" ") + b.WriteString(theme.CodeBlock) + b.WriteString(" ") + // Pad to innerWidth for consistent background + plain := StripANSI(line) + visW := 0 + for _, r := range plain { + visW += runeWidth(r) + } + b.WriteString(line) + if visW < innerWidth { + b.WriteString(strings.Repeat(" ", innerWidth-visW)) + } + b.WriteString(" ") + b.WriteString(theme.Reset) + b.WriteByte('\n') + } + + return strings.TrimRight(b.String(), "\n") +} + +// renderTableFromLines parses markdown table lines and renders with box-drawing. +func (r *MarkdownRenderer) renderTableFromLines(tableLines []string, width int) string { + if len(tableLines) == 0 { + return "" + } + + // Parse rows + var rows [][]string + for _, line := range tableLines { + // Skip separator lines (e.g., |---|---|) + if reRendererTableSep.MatchString(line) { + continue + } + cells := parseTableRow(line) + if len(cells) > 0 { + rows = append(rows, cells) + } + } + + if len(rows) == 0 { + return "" + } + + return RenderTable(rows) +} + +// parseTableRow splits a table row like "|a|b|c|" into cells. +func parseTableRow(line string) []string { + // Remove leading/trailing | + line = strings.TrimSpace(line) + line = strings.TrimPrefix(line, "|") + line = strings.TrimSuffix(line, "|") + parts := strings.Split(line, "|") + cells := make([]string, len(parts)) + for i, p := range parts { + cells[i] = strings.TrimSpace(p) + } + return cells +} + +// HighlightCode performs regex-based syntax highlighting for common languages. +// Supports Go, Python, JavaScript/TypeScript, Rust, and Bash. +func HighlightCode(code string, language string) string { + lang := strings.ToLower(language) + + // Only highlight supported languages + switch lang { + case "go", "golang", "python", "py", "javascript", "js", "typescript", "ts", "rust", "rs", "bash", "sh", "shell", "zsh": + // proceed + default: + return code + } + + // ANSI color codes for syntax elements + const ( + keywordColor = "\x1b[38;5;198m" // magenta/pink for keywords + stringColor = "\x1b[38;5;113m" // green for strings + commentColor = "\x1b[38;5;242m" // gray for comments + numberColor = "\x1b[38;5;141m" // purple for numbers + resetColor = "\x1b[0m" + ) + + // Process line by line to handle comments correctly + lines := strings.Split(code, "\n") + var result []string + for _, line := range lines { + highlighted := line + + // Comments first (they override everything else on the line) + if loc := reHighlightComment.FindStringIndex(highlighted); loc != nil { + before := highlighted[:loc[0]] + comment := highlighted[loc[0]:loc[1]] + after := highlighted[loc[1]:] + before = highlightNonComment(before, keywordColor, stringColor, numberColor, resetColor) + highlighted = before + commentColor + comment + resetColor + after + } else { + highlighted = highlightNonComment(highlighted, keywordColor, stringColor, numberColor, resetColor) + } + + result = append(result, highlighted) + } + + return strings.Join(result, "\n") +} + +// highlightNonComment highlights keywords, strings, and numbers in non-comment text. +func highlightNonComment(text, keywordColor, stringColor, numberColor, resetColor string) string { + // Strings first (so keywords inside strings are not highlighted) + text = reHighlightString.ReplaceAllStringFunc(text, func(m string) string { + return stringColor + m + resetColor + }) + + // Keywords (only highlight if not inside a string - simplified approach) + text = reHighlightKeyword.ReplaceAllStringFunc(text, func(m string) string { + return keywordColor + m + resetColor + }) + + // Numbers + text = reHighlightNumber.ReplaceAllStringFunc(text, func(m string) string { + // Don't highlight numbers that are part of ANSI escape sequences + return numberColor + m + resetColor + }) + + return text +} + +// WrapText performs word-wrapping at the specified width boundary. +// It respects ANSI escape codes by measuring only visible characters. +func WrapText(text string, width int) string { + if width <= 0 { + width = 80 + } + if text == "" { + return "" + } + + // Quick check: if text already fits, return as-is + plainLen := len(StripANSI(text)) + if plainLen <= width { + return text + } + + var result strings.Builder + words := strings.Fields(text) + curWidth := 0 + + for _, word := range words { + wordW := visibleWidth(word) + if curWidth > 0 && curWidth+1+wordW > width { + result.WriteByte('\n') + result.WriteString(word) + curWidth = wordW + } else if curWidth > 0 { + result.WriteByte(' ') + result.WriteString(word) + curWidth += 1 + wordW + } else { + result.WriteString(word) + curWidth = wordW + } + } + return result.String() +} + +// RenderTable renders a table with box-drawing characters. +// The first row is treated as the header. Column widths are auto-calculated. +func RenderTable(rows [][]string) string { + if len(rows) == 0 { + return "" + } + + // Determine number of columns + numCols := 0 + for _, row := range rows { + if len(row) > numCols { + numCols = len(row) + } + } + if numCols == 0 { + return "" + } + + // Calculate column widths + colWidths := make([]int, numCols) + for _, row := range rows { + for i, cell := range row { + if i < numCols { + w := len(StripANSI(cell)) + if w > colWidths[i] { + colWidths[i] = w + } + } + } + } + + // Ensure minimum width of 3 + for i := range colWidths { + if colWidths[i] < 3 { + colWidths[i] = 3 + } + } + + var b strings.Builder + + // Top border: ┌───┬───┐ + b.WriteString("┌") + for i, w := range colWidths { + b.WriteString(strings.Repeat("─", w+2)) + if i < numCols-1 { + b.WriteString("┬") + } + } + b.WriteString("┐\n") + + for rowIdx, row := range rows { + // Row content: │ cell │ cell │ + b.WriteString("│") + for i := 0; i < numCols; i++ { + cell := "" + if i < len(row) { + cell = row[i] + } + plainCell := StripANSI(cell) + pad := colWidths[i] - len(plainCell) + if pad < 0 { + pad = 0 + } + b.WriteString(" ") + b.WriteString(cell) + b.WriteString(strings.Repeat(" ", pad)) + b.WriteString(" │") + } + b.WriteString("\n") + + // After header row: ├───┼───┤ + if rowIdx == 0 && len(rows) > 1 { + b.WriteString("├") + for i, w := range colWidths { + b.WriteString(strings.Repeat("─", w+2)) + if i < numCols-1 { + b.WriteString("┼") + } + } + b.WriteString("┤\n") + } + } + + // Bottom border: └───┴───┘ + b.WriteString("└") + for i, w := range colWidths { + b.WriteString(strings.Repeat("─", w+2)) + if i < numCols-1 { + b.WriteString("┴") + } + } + b.WriteString("┘") + + return b.String() +} + +// StripANSI removes all ANSI escape codes from a string (for plain output). +func StripANSI(text string) string { + return reAnsi.ReplaceAllString(text, "") +} + +// RenderStreaming takes a channel of raw markdown chunks and returns a channel +// of rendered chunks. It buffers partial markdown elements until they can be +// completely rendered. +func RenderStreaming(ch <-chan string) <-chan string { + out := make(chan string, 16) + + go func() { + defer close(out) + + renderer := NewMarkdownRenderer(80) + var buffer strings.Builder + var lastRendered string + + for chunk := range ch { + buffer.WriteString(chunk) + current := buffer.String() + + // Check if we have incomplete elements that need buffering + if hasIncompleteElement(current) { + // Try to render what we can + safe := findSafeRenderPoint(current) + if safe == "" { + continue // buffer more + } + rendered := renderer.Render(safe) + if rendered != lastRendered { + // Send only the new part + diff := computeStreamDiff(lastRendered, rendered) + if diff != "" { + out <- diff + } + lastRendered = rendered + } + } else { + rendered := renderer.Render(current) + if rendered != lastRendered { + diff := computeStreamDiff(lastRendered, rendered) + if diff != "" { + out <- diff + } + lastRendered = rendered + } + } + } + + // Final flush + final := renderer.Render(buffer.String()) + if final != lastRendered { + diff := computeStreamDiff(lastRendered, final) + if diff != "" { + out <- diff + } + } + }() + + return out +} + +// hasIncompleteElement checks for partial markdown elements that should be buffered. +func hasIncompleteElement(s string) bool { + // Unclosed bold + count := strings.Count(s, "**") + if count%2 != 0 { + return true + } + + // Unclosed inline code + inCode := false + for _, ch := range s { + if ch == '`' { + inCode = !inCode + } + } + if inCode { + return true + } + + // Unclosed fenced code block + fenceCount := 0 + for _, line := range strings.Split(s, "\n") { + if strings.HasPrefix(strings.TrimSpace(line), "```") { + fenceCount++ + } + } + return fenceCount%2 != 0 +} + +// findSafeRenderPoint finds the longest prefix that can be safely rendered. +func findSafeRenderPoint(s string) string { + // Try to find the last complete line + lastNewline := strings.LastIndex(s, "\n") + if lastNewline <= 0 { + return "" + } + + candidate := s[:lastNewline] + // Verify this candidate doesn't have incomplete elements + if !hasIncompleteElement(candidate) { + return candidate + } + + // Try second-to-last newline + prevNewline := strings.LastIndex(candidate, "\n") + if prevNewline > 0 { + candidate = s[:prevNewline] + if !hasIncompleteElement(candidate) { + return candidate + } + } + + return "" +} + +// computeStreamDiff computes what new content to emit given old and new rendered text. +func computeStreamDiff(old, new string) string { + if old == "" { + return new + } + if strings.HasPrefix(new, old) { + return new[len(old):] + } + // Content changed (re-rendering), send full new content with clear + return "\r\x1b[J" + new +} + +// runeWidth returns the display width of a single rune. +func runeWidth(r rune) int { + if r == '\t' { + return 4 + } + if !unicode.IsPrint(r) { + return 0 + } + // Use East Asian width awareness + if unicode.Is(unicode.Han, r) || unicode.Is(unicode.Hangul, r) || unicode.Is(unicode.Katakana, r) || unicode.Is(unicode.Hiragana, r) { + return 2 + } + return 1 +} From 787b01f568767c55a6da4dfa03f5d4cf6febecde Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:19:37 +0530 Subject: [PATCH 11/20] refactor(engine): split semantic diff AST helpers into semantic_diff_helpers.go --- internal/engine/semantic_diff.go | 487 ---------------------- internal/engine/semantic_diff_helpers.go | 498 +++++++++++++++++++++++ 2 files changed, 498 insertions(+), 487 deletions(-) create mode 100644 internal/engine/semantic_diff_helpers.go diff --git a/internal/engine/semantic_diff.go b/internal/engine/semantic_diff.go index 2a092a25..1167944f 100644 --- a/internal/engine/semantic_diff.go +++ b/internal/engine/semantic_diff.go @@ -2,9 +2,6 @@ package engine import ( "fmt" - "go/ast" - "go/parser" - "go/token" "regexp" "sort" "strings" @@ -619,487 +616,3 @@ func CompareSignatures(oldSig, newSig string) *SignatureChange { return change } - -// --- Internal helpers --- - -// parseFunctions extracts exported and unexported function signatures from Go source. -func parseFunctions(content string) map[string]string { - funcs := make(map[string]string) - if content == "" { - return funcs - } - - fset := token.NewFileSet() - // Wrap content in a package declaration if needed - src := ensurePackage(content) - file, err := parser.ParseFile(fset, "", src, parser.AllErrors) - if err != nil { - // Fallback to regex-based parsing - return parseFunctionsRegex(content) - } - - for _, decl := range file.Decls { - fn, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - - name := fn.Name.Name - if fn.Recv != nil && len(fn.Recv.List) > 0 { - recvType := formatFieldType(fn.Recv.List[0].Type) - name = recvType + "." + fn.Name.Name - } - - sig := formatFuncSignature(fn) - funcs[name] = sig - } - - return funcs -} - -// parseTypes extracts type definitions from Go source. -func parseTypes(content string) map[string]string { - types := make(map[string]string) - if content == "" { - return types - } - - fset := token.NewFileSet() - src := ensurePackage(content) - file, err := parser.ParseFile(fset, "", src, parser.AllErrors) - if err != nil { - return types - } - - for _, decl := range file.Decls { - gen, ok := decl.(*ast.GenDecl) - if !ok || gen.Tok != token.TYPE { - continue - } - for _, spec := range gen.Specs { - ts, ok := spec.(*ast.TypeSpec) - if !ok { - continue - } - types[ts.Name.Name] = formatNode(fset, ts.Type) - } - } - - return types -} - -// parseInterfaces extracts interface type definitions and their methods. -func parseInterfaces(content string) map[string]map[string]bool { - interfaces := make(map[string]map[string]bool) - if content == "" { - return interfaces - } - - fset := token.NewFileSet() - src := ensurePackage(content) - file, err := parser.ParseFile(fset, "", src, parser.AllErrors) - if err != nil { - return interfaces - } - - for _, decl := range file.Decls { - gen, ok := decl.(*ast.GenDecl) - if !ok || gen.Tok != token.TYPE { - continue - } - for _, spec := range gen.Specs { - ts, ok := spec.(*ast.TypeSpec) - if !ok { - continue - } - iface, ok := ts.Type.(*ast.InterfaceType) - if !ok { - continue - } - methods := make(map[string]bool) - if iface.Methods != nil { - for _, m := range iface.Methods.List { - for _, name := range m.Names { - methods[name.Name] = true - } - } - } - interfaces[ts.Name.Name] = methods - } - } - - return interfaces -} - -// parseFuncBodies extracts function bodies as strings keyed by function name. -func parseFuncBodies(content string) map[string]string { - bodies := make(map[string]string) - if content == "" { - return bodies - } - - fset := token.NewFileSet() - src := ensurePackage(content) - file, err := parser.ParseFile(fset, "", src, parser.AllErrors) - if err != nil { - return bodies - } - - for _, decl := range file.Decls { - fn, ok := decl.(*ast.FuncDecl) - if !ok || fn.Body == nil { - continue - } - - name := fn.Name.Name - if fn.Recv != nil && len(fn.Recv.List) > 0 { - recvType := formatFieldType(fn.Recv.List[0].Type) - name = recvType + "." + fn.Name.Name - } - - start := fset.Position(fn.Body.Pos()).Offset - end := fset.Position(fn.Body.End()).Offset - if start >= 0 && end <= len(src) && start < end { - bodies[name] = src[start:end] - } - } - - return bodies -} - -// detectImportChanges identifies added and removed imports. -func detectImportChanges(oldContent, newContent string) []SemanticChange { - var changes []SemanticChange - - oldImports := parseImports(oldContent) - newImports := parseImports(newContent) - - for imp := range newImports { - if !oldImports[imp] { - changes = append(changes, SemanticChange{ - Type: "import_added", - Name: imp, - Description: fmt.Sprintf("Import added: %s", imp), - Breaking: false, - Risk: "low", - }) - } - } - - for imp := range oldImports { - if !newImports[imp] { - changes = append(changes, SemanticChange{ - Type: "import_removed", - Name: imp, - Description: fmt.Sprintf("Import removed: %s", imp), - Breaking: false, - Risk: "low", - }) - } - } - - return changes -} - -// detectAddedFunctions finds functions present in new content but not in old. -func detectAddedFunctions(oldContent, newContent string) []SemanticChange { - var changes []SemanticChange - - oldFuncs := parseFunctions(oldContent) - newFuncs := parseFunctions(newContent) - - for name, sig := range newFuncs { - if _, exists := oldFuncs[name]; !exists { - changes = append(changes, SemanticChange{ - Type: "function_added", - Name: name, - Description: fmt.Sprintf("Function added: %s", sig), - Breaking: false, - Risk: "low", - }) - } - } - - return changes -} - -// parseImports extracts import paths from Go source. -func parseImports(content string) map[string]bool { - imports := make(map[string]bool) - if content == "" { - return imports - } - - fset := token.NewFileSet() - src := ensurePackage(content) - file, err := parser.ParseFile(fset, "", src, parser.ImportsOnly) - if err != nil { - return imports - } - - for _, imp := range file.Imports { - path := strings.Trim(imp.Path.Value, `"`) - imports[path] = true - } - - return imports -} - -// ensurePackage wraps content with a package clause if missing. -func ensurePackage(content string) string { - trimmed := strings.TrimSpace(content) - if strings.HasPrefix(trimmed, "package ") { - return content - } - return "package _semantic_diff_analysis\n\n" + content -} - -// formatFuncSignature formats a function declaration's signature as a string. -func formatFuncSignature(fn *ast.FuncDecl) string { - var sb strings.Builder - - sb.WriteString("func ") - - if fn.Recv != nil && len(fn.Recv.List) > 0 { - sb.WriteString("(") - sb.WriteString(formatFieldType(fn.Recv.List[0].Type)) - sb.WriteString(") ") - } - - sb.WriteString(fn.Name.Name) - sb.WriteString("(") - - if fn.Type.Params != nil { - params := make([]string, 0) - for _, field := range fn.Type.Params.List { - typeName := formatFieldType(field.Type) - if len(field.Names) == 0 { - params = append(params, typeName) - } else { - for _, name := range field.Names { - params = append(params, name.Name+" "+typeName) - } - } - } - sb.WriteString(strings.Join(params, ", ")) - } - - sb.WriteString(")") - - if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 { - results := make([]string, 0) - for _, field := range fn.Type.Results.List { - typeName := formatFieldType(field.Type) - if len(field.Names) == 0 { - results = append(results, typeName) - } else { - for _, name := range field.Names { - results = append(results, name.Name+" "+typeName) - } - } - } - if len(results) == 1 { - sb.WriteString(" " + results[0]) - } else { - sb.WriteString(" (" + strings.Join(results, ", ") + ")") - } - } - - return sb.String() -} - -// formatFieldType formats an AST expression representing a type. -func formatFieldType(expr ast.Expr) string { - switch t := expr.(type) { - case *ast.Ident: - return t.Name - case *ast.StarExpr: - return "*" + formatFieldType(t.X) - case *ast.SelectorExpr: - return formatFieldType(t.X) + "." + t.Sel.Name - case *ast.ArrayType: - if t.Len == nil { - return "[]" + formatFieldType(t.Elt) - } - return "[...]" + formatFieldType(t.Elt) - case *ast.MapType: - return "map[" + formatFieldType(t.Key) + "]" + formatFieldType(t.Value) - case *ast.InterfaceType: - return "interface{}" - case *ast.FuncType: - return "func(...)" - case *ast.Ellipsis: - return "..." + formatFieldType(t.Elt) - case *ast.ChanType: - return "chan " + formatFieldType(t.Value) - default: - return "unknown" - } -} - -// formatNode returns a simple string representation of a type node. -func formatNode(fset *token.FileSet, node ast.Node) string { - switch t := node.(type) { - case *ast.StructType: - var fields []string - if t.Fields != nil { - for _, f := range t.Fields.List { - typeName := formatFieldType(f.Type) - for _, name := range f.Names { - fields = append(fields, name.Name+" "+typeName) - } - if len(f.Names) == 0 { - fields = append(fields, typeName) - } - } - } - return "struct{" + strings.Join(fields, "; ") + "}" - case *ast.InterfaceType: - var methods []string - if t.Methods != nil { - for _, m := range t.Methods.List { - for _, name := range m.Names { - methods = append(methods, name.Name) - } - } - } - return "interface{" + strings.Join(methods, "; ") + "}" - case *ast.Ident: - return t.Name - default: - if expr, ok := node.(ast.Expr); ok { - return formatFieldType(expr) - } - return "unknown" - } -} - -// parseFunctionsRegex is a fallback for when AST parsing fails. -func parseFunctionsRegex(content string) map[string]string { - funcs := make(map[string]string) - re := regexp.MustCompile(`func\s+(?:\([^)]*\)\s+)?(\w+)\s*\([^)]*\)(?:\s*(?:\([^)]*\)|[^\s{]+))?\s*{`) - matches := re.FindAllStringSubmatch(content, -1) - for _, m := range matches { - if len(m) > 1 { - funcs[m[1]] = m[0] - } - } - return funcs -} - -// isExported returns true if the name starts with an uppercase letter. -func isExported(name string) bool { - if name == "" { - return false - } - // Handle receiver.method format - parts := strings.Split(name, ".") - checkName := parts[len(parts)-1] - if checkName == "" { - return false - } - return checkName[0] >= 'A' && checkName[0] <= 'Z' -} - -// countPattern counts regex matches in text. -func countPattern(text, pattern string) int { - re := regexp.MustCompile(pattern) - return len(re.FindAllString(text, -1)) -} - -// extractLoopBounds extracts a hash of loop conditions for comparison. -func extractLoopBounds(body string) string { - re := regexp.MustCompile(`for\s+([^{]+){`) - matches := re.FindAllStringSubmatch(body, -1) - var bounds []string - for _, m := range matches { - if len(m) > 1 { - bounds = append(bounds, strings.TrimSpace(m[1])) - } - } - sort.Strings(bounds) - return strings.Join(bounds, ";") -} - -// extractRoutes finds route patterns and associated handler names in content. -func (sa *SemanticAnalyzer) extractRoutes(content string) map[string][]string { - routes := make(map[string][]string) - lines := strings.Split(content, "\n") - - for _, line := range lines { - for _, pattern := range sa.routePatterns { - matches := pattern.FindStringSubmatch(line) - if len(matches) > 1 { - route := matches[1] - // Extract handler name from the same line - handlerRe := regexp.MustCompile(`(\w+)\s*[,)]`) - handlerMatches := handlerRe.FindAllStringSubmatch(line, -1) - for _, hm := range handlerMatches { - if len(hm) > 1 && hm[1] != "http" && hm[1] != "func" { - routes[route] = append(routes[route], hm[1]) - } - } - } - } - } - - return routes -} - -// formatChangeLine formats a SemanticChange for display in the summary. -func formatChangeLine(c SemanticChange) string { - if c.Description != "" { - return c.Description - } - return c.Name -} - -// extractFuncName extracts the function name from a signature string. -func extractFuncName(sig string) string { - re := regexp.MustCompile(`func\s+(?:\([^)]*\)\s+)?(\w+)`) - matches := re.FindStringSubmatch(sig) - if len(matches) > 1 { - return matches[1] - } - return "" -} - -// extractReceiver extracts the receiver type from a signature string. -func extractReceiver(sig string) string { - re := regexp.MustCompile(`func\s+\(([^)]*)\)`) - matches := re.FindStringSubmatch(sig) - if len(matches) > 1 { - return strings.TrimSpace(matches[1]) - } - return "" -} - -// extractParams extracts parameter list from a function signature. -func extractParams(sig string) []string { - // Find params after the function name - re := regexp.MustCompile(`func\s+(?:\([^)]*\)\s+)?\w+\(([^)]*)\)`) - matches := re.FindStringSubmatch(sig) - if len(matches) < 2 || matches[1] == "" { - return nil - } - params := strings.Split(matches[1], ",") - var result []string - for _, p := range params { - trimmed := strings.TrimSpace(p) - if trimmed != "" { - result = append(result, trimmed) - } - } - return result -} - -// extractReturnType extracts return type(s) from a function signature. -func extractReturnType(sig string) string { - // Find the last closing paren of params, then get everything after - re := regexp.MustCompile(`func\s+(?:\([^)]*\)\s+)?\w+\([^)]*\)\s*(.*)$`) - matches := re.FindStringSubmatch(sig) - if len(matches) > 1 { - return strings.TrimSpace(matches[1]) - } - return "" -} diff --git a/internal/engine/semantic_diff_helpers.go b/internal/engine/semantic_diff_helpers.go new file mode 100644 index 00000000..2b84ddc1 --- /dev/null +++ b/internal/engine/semantic_diff_helpers.go @@ -0,0 +1,498 @@ +package engine + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "regexp" + "sort" + "strings" +) + +// This file holds the internal helpers for semantic diff analysis: AST-based +// parsing of Go source, import/function detection, and signature string +// extraction. The SemanticAnalyzer type and the high-level analysis entry +// points live in semantic_diff.go. + +// parseFunctions extracts exported and unexported function signatures from Go source. +func parseFunctions(content string) map[string]string { + funcs := make(map[string]string) + if content == "" { + return funcs + } + + fset := token.NewFileSet() + // Wrap content in a package declaration if needed + src := ensurePackage(content) + file, err := parser.ParseFile(fset, "", src, parser.AllErrors) + if err != nil { + // Fallback to regex-based parsing + return parseFunctionsRegex(content) + } + + for _, decl := range file.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + + name := fn.Name.Name + if fn.Recv != nil && len(fn.Recv.List) > 0 { + recvType := formatFieldType(fn.Recv.List[0].Type) + name = recvType + "." + fn.Name.Name + } + + sig := formatFuncSignature(fn) + funcs[name] = sig + } + + return funcs +} + +// parseTypes extracts type definitions from Go source. +func parseTypes(content string) map[string]string { + types := make(map[string]string) + if content == "" { + return types + } + + fset := token.NewFileSet() + src := ensurePackage(content) + file, err := parser.ParseFile(fset, "", src, parser.AllErrors) + if err != nil { + return types + } + + for _, decl := range file.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.TYPE { + continue + } + for _, spec := range gen.Specs { + ts, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + types[ts.Name.Name] = formatNode(fset, ts.Type) + } + } + + return types +} + +// parseInterfaces extracts interface type definitions and their methods. +func parseInterfaces(content string) map[string]map[string]bool { + interfaces := make(map[string]map[string]bool) + if content == "" { + return interfaces + } + + fset := token.NewFileSet() + src := ensurePackage(content) + file, err := parser.ParseFile(fset, "", src, parser.AllErrors) + if err != nil { + return interfaces + } + + for _, decl := range file.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.TYPE { + continue + } + for _, spec := range gen.Specs { + ts, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + iface, ok := ts.Type.(*ast.InterfaceType) + if !ok { + continue + } + methods := make(map[string]bool) + if iface.Methods != nil { + for _, m := range iface.Methods.List { + for _, name := range m.Names { + methods[name.Name] = true + } + } + } + interfaces[ts.Name.Name] = methods + } + } + + return interfaces +} + +// parseFuncBodies extracts function bodies as strings keyed by function name. +func parseFuncBodies(content string) map[string]string { + bodies := make(map[string]string) + if content == "" { + return bodies + } + + fset := token.NewFileSet() + src := ensurePackage(content) + file, err := parser.ParseFile(fset, "", src, parser.AllErrors) + if err != nil { + return bodies + } + + for _, decl := range file.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok || fn.Body == nil { + continue + } + + name := fn.Name.Name + if fn.Recv != nil && len(fn.Recv.List) > 0 { + recvType := formatFieldType(fn.Recv.List[0].Type) + name = recvType + "." + fn.Name.Name + } + + start := fset.Position(fn.Body.Pos()).Offset + end := fset.Position(fn.Body.End()).Offset + if start >= 0 && end <= len(src) && start < end { + bodies[name] = src[start:end] + } + } + + return bodies +} + +// detectImportChanges identifies added and removed imports. +func detectImportChanges(oldContent, newContent string) []SemanticChange { + var changes []SemanticChange + + oldImports := parseImports(oldContent) + newImports := parseImports(newContent) + + for imp := range newImports { + if !oldImports[imp] { + changes = append(changes, SemanticChange{ + Type: "import_added", + Name: imp, + Description: fmt.Sprintf("Import added: %s", imp), + Breaking: false, + Risk: "low", + }) + } + } + + for imp := range oldImports { + if !newImports[imp] { + changes = append(changes, SemanticChange{ + Type: "import_removed", + Name: imp, + Description: fmt.Sprintf("Import removed: %s", imp), + Breaking: false, + Risk: "low", + }) + } + } + + return changes +} + +// detectAddedFunctions finds functions present in new content but not in old. +func detectAddedFunctions(oldContent, newContent string) []SemanticChange { + var changes []SemanticChange + + oldFuncs := parseFunctions(oldContent) + newFuncs := parseFunctions(newContent) + + for name, sig := range newFuncs { + if _, exists := oldFuncs[name]; !exists { + changes = append(changes, SemanticChange{ + Type: "function_added", + Name: name, + Description: fmt.Sprintf("Function added: %s", sig), + Breaking: false, + Risk: "low", + }) + } + } + + return changes +} + +// parseImports extracts import paths from Go source. +func parseImports(content string) map[string]bool { + imports := make(map[string]bool) + if content == "" { + return imports + } + + fset := token.NewFileSet() + src := ensurePackage(content) + file, err := parser.ParseFile(fset, "", src, parser.ImportsOnly) + if err != nil { + return imports + } + + for _, imp := range file.Imports { + path := strings.Trim(imp.Path.Value, `"`) + imports[path] = true + } + + return imports +} + +// ensurePackage wraps content with a package clause if missing. +func ensurePackage(content string) string { + trimmed := strings.TrimSpace(content) + if strings.HasPrefix(trimmed, "package ") { + return content + } + return "package _semantic_diff_analysis\n\n" + content +} + +// formatFuncSignature formats a function declaration's signature as a string. +func formatFuncSignature(fn *ast.FuncDecl) string { + var sb strings.Builder + + sb.WriteString("func ") + + if fn.Recv != nil && len(fn.Recv.List) > 0 { + sb.WriteString("(") + sb.WriteString(formatFieldType(fn.Recv.List[0].Type)) + sb.WriteString(") ") + } + + sb.WriteString(fn.Name.Name) + sb.WriteString("(") + + if fn.Type.Params != nil { + params := make([]string, 0) + for _, field := range fn.Type.Params.List { + typeName := formatFieldType(field.Type) + if len(field.Names) == 0 { + params = append(params, typeName) + } else { + for _, name := range field.Names { + params = append(params, name.Name+" "+typeName) + } + } + } + sb.WriteString(strings.Join(params, ", ")) + } + + sb.WriteString(")") + + if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 { + results := make([]string, 0) + for _, field := range fn.Type.Results.List { + typeName := formatFieldType(field.Type) + if len(field.Names) == 0 { + results = append(results, typeName) + } else { + for _, name := range field.Names { + results = append(results, name.Name+" "+typeName) + } + } + } + if len(results) == 1 { + sb.WriteString(" " + results[0]) + } else { + sb.WriteString(" (" + strings.Join(results, ", ") + ")") + } + } + + return sb.String() +} + +// formatFieldType formats an AST expression representing a type. +func formatFieldType(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.Ident: + return t.Name + case *ast.StarExpr: + return "*" + formatFieldType(t.X) + case *ast.SelectorExpr: + return formatFieldType(t.X) + "." + t.Sel.Name + case *ast.ArrayType: + if t.Len == nil { + return "[]" + formatFieldType(t.Elt) + } + return "[...]" + formatFieldType(t.Elt) + case *ast.MapType: + return "map[" + formatFieldType(t.Key) + "]" + formatFieldType(t.Value) + case *ast.InterfaceType: + return "interface{}" + case *ast.FuncType: + return "func(...)" + case *ast.Ellipsis: + return "..." + formatFieldType(t.Elt) + case *ast.ChanType: + return "chan " + formatFieldType(t.Value) + default: + return "unknown" + } +} + +// formatNode returns a simple string representation of a type node. +func formatNode(fset *token.FileSet, node ast.Node) string { + switch t := node.(type) { + case *ast.StructType: + var fields []string + if t.Fields != nil { + for _, f := range t.Fields.List { + typeName := formatFieldType(f.Type) + for _, name := range f.Names { + fields = append(fields, name.Name+" "+typeName) + } + if len(f.Names) == 0 { + fields = append(fields, typeName) + } + } + } + return "struct{" + strings.Join(fields, "; ") + "}" + case *ast.InterfaceType: + var methods []string + if t.Methods != nil { + for _, m := range t.Methods.List { + for _, name := range m.Names { + methods = append(methods, name.Name) + } + } + } + return "interface{" + strings.Join(methods, "; ") + "}" + case *ast.Ident: + return t.Name + default: + if expr, ok := node.(ast.Expr); ok { + return formatFieldType(expr) + } + return "unknown" + } +} + +// parseFunctionsRegex is a fallback for when AST parsing fails. +func parseFunctionsRegex(content string) map[string]string { + funcs := make(map[string]string) + re := regexp.MustCompile(`func\s+(?:\([^)]*\)\s+)?(\w+)\s*\([^)]*\)(?:\s*(?:\([^)]*\)|[^\s{]+))?\s*{`) + matches := re.FindAllStringSubmatch(content, -1) + for _, m := range matches { + if len(m) > 1 { + funcs[m[1]] = m[0] + } + } + return funcs +} + +// isExported returns true if the name starts with an uppercase letter. +func isExported(name string) bool { + if name == "" { + return false + } + // Handle receiver.method format + parts := strings.Split(name, ".") + checkName := parts[len(parts)-1] + if checkName == "" { + return false + } + return checkName[0] >= 'A' && checkName[0] <= 'Z' +} + +// countPattern counts regex matches in text. +func countPattern(text, pattern string) int { + re := regexp.MustCompile(pattern) + return len(re.FindAllString(text, -1)) +} + +// extractLoopBounds extracts a hash of loop conditions for comparison. +func extractLoopBounds(body string) string { + re := regexp.MustCompile(`for\s+([^{]+){`) + matches := re.FindAllStringSubmatch(body, -1) + var bounds []string + for _, m := range matches { + if len(m) > 1 { + bounds = append(bounds, strings.TrimSpace(m[1])) + } + } + sort.Strings(bounds) + return strings.Join(bounds, ";") +} + +// extractRoutes finds route patterns and associated handler names in content. +func (sa *SemanticAnalyzer) extractRoutes(content string) map[string][]string { + routes := make(map[string][]string) + lines := strings.Split(content, "\n") + + for _, line := range lines { + for _, pattern := range sa.routePatterns { + matches := pattern.FindStringSubmatch(line) + if len(matches) > 1 { + route := matches[1] + // Extract handler name from the same line + handlerRe := regexp.MustCompile(`(\w+)\s*[,)]`) + handlerMatches := handlerRe.FindAllStringSubmatch(line, -1) + for _, hm := range handlerMatches { + if len(hm) > 1 && hm[1] != "http" && hm[1] != "func" { + routes[route] = append(routes[route], hm[1]) + } + } + } + } + } + + return routes +} + +// formatChangeLine formats a SemanticChange for display in the summary. +func formatChangeLine(c SemanticChange) string { + if c.Description != "" { + return c.Description + } + return c.Name +} + +// extractFuncName extracts the function name from a signature string. +func extractFuncName(sig string) string { + re := regexp.MustCompile(`func\s+(?:\([^)]*\)\s+)?(\w+)`) + matches := re.FindStringSubmatch(sig) + if len(matches) > 1 { + return matches[1] + } + return "" +} + +// extractReceiver extracts the receiver type from a signature string. +func extractReceiver(sig string) string { + re := regexp.MustCompile(`func\s+\(([^)]*)\)`) + matches := re.FindStringSubmatch(sig) + if len(matches) > 1 { + return strings.TrimSpace(matches[1]) + } + return "" +} + +// extractParams extracts parameter list from a function signature. +func extractParams(sig string) []string { + // Find params after the function name + re := regexp.MustCompile(`func\s+(?:\([^)]*\)\s+)?\w+\(([^)]*)\)`) + matches := re.FindStringSubmatch(sig) + if len(matches) < 2 || matches[1] == "" { + return nil + } + params := strings.Split(matches[1], ",") + var result []string + for _, p := range params { + trimmed := strings.TrimSpace(p) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} + +// extractReturnType extracts return type(s) from a function signature. +func extractReturnType(sig string) string { + // Find the last closing paren of params, then get everything after + re := regexp.MustCompile(`func\s+(?:\([^)]*\)\s+)?\w+\([^)]*\)\s*(.*)$`) + matches := re.FindStringSubmatch(sig) + if len(matches) > 1 { + return strings.TrimSpace(matches[1]) + } + return "" +} From 173b335a5c59af56a08b854a0e9d5020df85101d Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:21:38 +0530 Subject: [PATCH 12/20] refactor(repomap): split summary helpers into summary_helpers.go --- internal/intelligence/repomap/summary.go | 411 +---------------- .../intelligence/repomap/summary_helpers.go | 420 ++++++++++++++++++ 2 files changed, 423 insertions(+), 408 deletions(-) create mode 100644 internal/intelligence/repomap/summary_helpers.go diff --git a/internal/intelligence/repomap/summary.go b/internal/intelligence/repomap/summary.go index 89653503..b267b06b 100644 --- a/internal/intelligence/repomap/summary.go +++ b/internal/intelligence/repomap/summary.go @@ -6,18 +6,14 @@ package repomap import ( - "bufio" - "encoding/json" "fmt" "io/fs" "os" "path/filepath" - "regexp" "sort" "strings" "sync" "time" - "unicode" ) // CodebaseSummary holds the high-level overview of a repository. @@ -641,407 +637,6 @@ func RenderCompact(summary *CodebaseSummary) string { ) } -// ── Helper functions (prefixed to avoid conflicts with other files in package) ── - -func summaryDetectLanguage(projectDir string) string { - counts := map[string]int{} - - _ = filepath.WalkDir(projectDir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - if summarySkipDir(filepath.Base(path)) { - return filepath.SkipDir - } - return nil - } - ext := strings.ToLower(filepath.Ext(path)) - switch ext { - case ".go": - counts["Go"]++ - case ".py": - counts["Python"]++ - case ".ts", ".tsx": - counts["TypeScript"]++ - case ".js", ".jsx": - counts["JavaScript"]++ - case ".rs": - counts["Rust"]++ - case ".java": - counts["Java"]++ - case ".rb": - counts["Ruby"]++ - case ".c", ".h": - counts["C"]++ - case ".cpp", ".cc", ".cxx", ".hpp": - counts["C++"]++ - case ".cs": - counts["C#"]++ - } - return nil - }) - - if len(counts) == 0 { - return "Unknown" - } - - best := "" - bestCount := 0 - for lang, count := range counts { - if count > bestCount { - best = lang - bestCount = count - } - } - return best -} - -func summarySkipDir(name string) bool { - skip := []string{ - ".git", "node_modules", "vendor", "__pycache__", ".venv", "venv", - "dist", "build", ".next", ".nuxt", "target", "bin", "obj", - ".idea", ".vscode", ".DS_Store", ".cache", "coverage", - } - for _, s := range skip { - if name == s { - return true - } - } - return false -} - -func summaryIsSupportedFile(path string) bool { - ext := strings.ToLower(filepath.Ext(path)) - supported := map[string]bool{ - ".go": true, ".py": true, ".ts": true, ".tsx": true, - ".js": true, ".jsx": true, ".rs": true, ".java": true, - ".rb": true, ".c": true, ".h": true, ".cpp": true, - ".cc": true, ".cxx": true, ".hpp": true, ".cs": true, - } - return supported[ext] -} - -func summaryCountFileLines(path string) int { - f, err := os.Open(path) - if err != nil { - return 0 - } - defer func() { _ = f.Close() }() - - count := 0 - scanner := bufio.NewScanner(f) - for scanner.Scan() { - count++ - } - return count -} - -func summaryExtractSymbols(path string) []string { - data, err := os.ReadFile(path) - if err != nil { - return nil - } - src := string(data) - - var symbols []Symbol - ext := strings.ToLower(filepath.Ext(path)) - switch ext { - case ".go": - symbols = parseGo(src) - case ".py": - symbols = parsePython(src) - case ".ts", ".tsx", ".js", ".jsx": - symbols = parseTypeScript(src) - default: - return nil - } - - names := make([]string, 0, len(symbols)) - for _, s := range symbols { - names = append(names, s.Name) - } - return names -} - -func summaryCountPublicSymbols(symbols []string, lang string) int { - count := 0 - for _, s := range symbols { - if summaryIsPublicSymbol(s, lang) { - count++ - } - } - return count -} - -func summaryIsPublicSymbol(name string, lang string) bool { - if name == "" { - return false - } - switch lang { - case "Go": - // In Go, public symbols start with uppercase - return unicode.IsUpper(rune(name[0])) - case "Python": - // In Python, public symbols don't start with underscore - return !strings.HasPrefix(name, "_") - default: - // For JS/TS, we consider exported symbols public (parser already filters) - return unicode.IsUpper(rune(name[0])) || !strings.HasPrefix(name, "_") - } -} - -func summaryExtractImports(path string) []string { - f, err := os.Open(path) - if err != nil { - return nil - } - defer func() { _ = f.Close() }() - - var imports []string - scanner := bufio.NewScanner(f) - ext := strings.ToLower(filepath.Ext(path)) - inBlock := false - - for scanner.Scan() { - line := scanner.Text() - switch ext { - case ".go": - if goImportBlockRe.MatchString(line) { - inBlock = true - continue - } - if inBlock { - if goImportBlockEnd.MatchString(line) { - inBlock = false - continue - } - if m := goImportPathRe.FindStringSubmatch(line); m != nil { - imports = append(imports, m[1]) - } - } else if m := goImportSingleRe.FindStringSubmatch(line); m != nil { - imports = append(imports, m[1]) - } - case ".py": - if m := pyImportRe.FindStringSubmatch(line); m != nil { - imports = append(imports, m[1]) - } else if m := pyFromImportRe.FindStringSubmatch(line); m != nil { - imports = append(imports, m[1]) - } - case ".ts", ".tsx", ".js", ".jsx": - if m := tsImportFromRe.FindStringSubmatch(line); m != nil { - imports = append(imports, m[1]) - } else if m := tsImportBareRe.FindStringSubmatch(line); m != nil { - imports = append(imports, m[1]) - } - } - } - return imports -} - -var ( - summaryGoMainRe = regexp.MustCompile(`^func\s+main\s*\(`) - summaryGoPackageMainRe = regexp.MustCompile(`^package\s+main\b`) -) - -func summaryHasGoMain(path string) bool { - f, err := os.Open(path) - if err != nil { - return false - } - defer func() { _ = f.Close() }() - - hasPackageMain := false - hasFuncMain := false - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Text() - if summaryGoPackageMainRe.MatchString(line) { - hasPackageMain = true - } - if summaryGoMainRe.MatchString(line) { - hasFuncMain = true - } - } - return hasPackageMain && hasFuncMain -} - -var summaryPyMainRe = regexp.MustCompile(`^if\s+__name__\s*==\s*['"]__main__['"]`) - -func summaryHasPythonMain(path string) bool { - f, err := os.Open(path) - if err != nil { - return false - } - defer func() { _ = f.Close() }() - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - if summaryPyMainRe.MatchString(scanner.Text()) { - return true - } - } - return false -} - -func summaryFindJSEntryPoints(packageJSONPath string, projectDir string) []string { - data, err := os.ReadFile(packageJSONPath) - if err != nil { - return nil - } - - var pkg struct { - Main string `json:"main"` - } - if unmarshalErr := json.Unmarshal(data, &pkg); unmarshalErr != nil { - return nil - } - - if pkg.Main == "" { - return nil - } - - dir := filepath.Dir(packageJSONPath) - rel, err := filepath.Rel(projectDir, filepath.Join(dir, pkg.Main)) - if err != nil { - return nil - } - return []string{rel} -} - -func summaryCollectPackageSymbols(projectDir, pkgPath string) []string { - dir := filepath.Join(projectDir, pkgPath) - if pkgPath == summaryProjectRoot { - dir = projectDir - } - - var symbols []string - entries, err := os.ReadDir(dir) - if err != nil { - return nil - } - - for _, entry := range entries { - if entry.IsDir() { - continue - } - path := filepath.Join(dir, entry.Name()) - if summaryIsSupportedFile(path) { - symbols = append(symbols, summaryExtractSymbols(path)...) - } - } - return symbols -} - -func summaryIsConfigFile(path string) bool { - base := strings.ToLower(filepath.Base(path)) - configPatterns := []string{ - "config", "settings", "conf", ".env", "yaml", "yml", "toml", - "makefile", "dockerfile", "docker-compose", - } - for _, p := range configPatterns { - if strings.Contains(base, p) { - return true - } - } - return false -} - -func inferProjectDescription(name string, packages []SummaryPackageInfo, lang string) string { - if len(packages) == 0 { - return fmt.Sprintf("A %s project", lang) - } - - // Look for notable package names to infer purpose - hasAPI := false - hasCLI := false - hasWeb := false - hasEngine := false - - for _, pkg := range packages { - lower := strings.ToLower(pkg.Path) - if strings.Contains(lower, "api") || strings.Contains(lower, "handler") { - hasAPI = true - } - if strings.Contains(lower, "cmd") || strings.Contains(lower, "cli") { - hasCLI = true - } - if strings.Contains(lower, "web") || strings.Contains(lower, "frontend") { - hasWeb = true - } - if strings.Contains(lower, "engine") || strings.Contains(lower, "core") { - hasEngine = true - } - } - - switch { - case hasCLI && hasEngine: - return fmt.Sprintf("A %s CLI application with core engine", lang) - case hasCLI: - return fmt.Sprintf("A %s command-line application", lang) - case hasAPI && hasWeb: - return fmt.Sprintf("A %s full-stack web application", lang) - case hasAPI: - return fmt.Sprintf("A %s API service", lang) - case hasWeb: - return fmt.Sprintf("A %s web application", lang) - default: - return fmt.Sprintf("A %s project with %d packages", lang, len(packages)) - } -} - -func summaryDescribeArchitecture(summary *CodebaseSummary) string { - arch := summary.Architecture - - // Try to describe the layer flow for layered architectures - if arch == "layered" && len(summary.Packages) > 0 { - layers := make([]string, 0, 4) - layerNames := []string{"cmd", "engine", "service", "handler", "tool", "config", "store", "repo"} - for _, ln := range layerNames { - for _, pkg := range summary.Packages { - if strings.Contains(strings.ToLower(pkg.Path), ln) { - layers = append(layers, pkg.Path) - break - } - } - if len(layers) >= 4 { - break - } - } - if len(layers) >= 2 { - return fmt.Sprintf("layered (%s)", strings.Join(layers, " -> ")) - } - } - - return arch -} - -func summaryFormatNumber(n int) string { - if n >= 1000 { - return fmt.Sprintf("%d,%03d", n/1000, n%1000) - } - return fmt.Sprintf("%d", n) -} - -func summaryFormatLOC(loc int) string { - if loc >= 1000000 { - return fmt.Sprintf("%.1fM", float64(loc)/1000000.0) - } - if loc >= 1000 { - return fmt.Sprintf("%dK", loc/1000) - } - return fmt.Sprintf("%d", loc) -} - -func summaryFormatPackageName(name string) string { - if name == "" { - return "Unknown" - } - // Capitalize first letter - return strings.ToUpper(name[:1]) + name[1:] -} - -func summaryEstimateTokens(text string) int { - // Rough estimate: 1 token per 4 characters - return len(text) / 4 -} +// Helper functions for summary generation (language/file detection, symbol and +// import extraction, entry-point detection, formatting) live in +// summary_helpers.go. diff --git a/internal/intelligence/repomap/summary_helpers.go b/internal/intelligence/repomap/summary_helpers.go new file mode 100644 index 00000000..96a6bf51 --- /dev/null +++ b/internal/intelligence/repomap/summary_helpers.go @@ -0,0 +1,420 @@ +package repomap + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "unicode" +) + +// This file holds the internal helpers for codebase summary generation: +// language/file detection, symbol and import extraction, entry-point detection, +// and formatting. The CodebaseSummary type, SummaryGenerator, and the public +// Render/Infer/Find entry points live in summary.go. + +func summaryDetectLanguage(projectDir string) string { + counts := map[string]int{} + + _ = filepath.WalkDir(projectDir, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + if summarySkipDir(filepath.Base(path)) { + return filepath.SkipDir + } + return nil + } + ext := strings.ToLower(filepath.Ext(path)) + switch ext { + case ".go": + counts["Go"]++ + case ".py": + counts["Python"]++ + case ".ts", ".tsx": + counts["TypeScript"]++ + case ".js", ".jsx": + counts["JavaScript"]++ + case ".rs": + counts["Rust"]++ + case ".java": + counts["Java"]++ + case ".rb": + counts["Ruby"]++ + case ".c", ".h": + counts["C"]++ + case ".cpp", ".cc", ".cxx", ".hpp": + counts["C++"]++ + case ".cs": + counts["C#"]++ + } + return nil + }) + + if len(counts) == 0 { + return "Unknown" + } + + best := "" + bestCount := 0 + for lang, count := range counts { + if count > bestCount { + best = lang + bestCount = count + } + } + return best +} + +func summarySkipDir(name string) bool { + skip := []string{ + ".git", "node_modules", "vendor", "__pycache__", ".venv", "venv", + "dist", "build", ".next", ".nuxt", "target", "bin", "obj", + ".idea", ".vscode", ".DS_Store", ".cache", "coverage", + } + for _, s := range skip { + if name == s { + return true + } + } + return false +} + +func summaryIsSupportedFile(path string) bool { + ext := strings.ToLower(filepath.Ext(path)) + supported := map[string]bool{ + ".go": true, ".py": true, ".ts": true, ".tsx": true, + ".js": true, ".jsx": true, ".rs": true, ".java": true, + ".rb": true, ".c": true, ".h": true, ".cpp": true, + ".cc": true, ".cxx": true, ".hpp": true, ".cs": true, + } + return supported[ext] +} + +func summaryCountFileLines(path string) int { + f, err := os.Open(path) + if err != nil { + return 0 + } + defer func() { _ = f.Close() }() + + count := 0 + scanner := bufio.NewScanner(f) + for scanner.Scan() { + count++ + } + return count +} + +func summaryExtractSymbols(path string) []string { + data, err := os.ReadFile(path) + if err != nil { + return nil + } + src := string(data) + + var symbols []Symbol + ext := strings.ToLower(filepath.Ext(path)) + switch ext { + case ".go": + symbols = parseGo(src) + case ".py": + symbols = parsePython(src) + case ".ts", ".tsx", ".js", ".jsx": + symbols = parseTypeScript(src) + default: + return nil + } + + names := make([]string, 0, len(symbols)) + for _, s := range symbols { + names = append(names, s.Name) + } + return names +} + +func summaryCountPublicSymbols(symbols []string, lang string) int { + count := 0 + for _, s := range symbols { + if summaryIsPublicSymbol(s, lang) { + count++ + } + } + return count +} + +func summaryIsPublicSymbol(name string, lang string) bool { + if name == "" { + return false + } + switch lang { + case "Go": + // In Go, public symbols start with uppercase + return unicode.IsUpper(rune(name[0])) + case "Python": + // In Python, public symbols don't start with underscore + return !strings.HasPrefix(name, "_") + default: + // For JS/TS, we consider exported symbols public (parser already filters) + return unicode.IsUpper(rune(name[0])) || !strings.HasPrefix(name, "_") + } +} + +func summaryExtractImports(path string) []string { + f, err := os.Open(path) + if err != nil { + return nil + } + defer func() { _ = f.Close() }() + + var imports []string + scanner := bufio.NewScanner(f) + ext := strings.ToLower(filepath.Ext(path)) + inBlock := false + + for scanner.Scan() { + line := scanner.Text() + switch ext { + case ".go": + if goImportBlockRe.MatchString(line) { + inBlock = true + continue + } + if inBlock { + if goImportBlockEnd.MatchString(line) { + inBlock = false + continue + } + if m := goImportPathRe.FindStringSubmatch(line); m != nil { + imports = append(imports, m[1]) + } + } else if m := goImportSingleRe.FindStringSubmatch(line); m != nil { + imports = append(imports, m[1]) + } + case ".py": + if m := pyImportRe.FindStringSubmatch(line); m != nil { + imports = append(imports, m[1]) + } else if m := pyFromImportRe.FindStringSubmatch(line); m != nil { + imports = append(imports, m[1]) + } + case ".ts", ".tsx", ".js", ".jsx": + if m := tsImportFromRe.FindStringSubmatch(line); m != nil { + imports = append(imports, m[1]) + } else if m := tsImportBareRe.FindStringSubmatch(line); m != nil { + imports = append(imports, m[1]) + } + } + } + return imports +} + +var ( + summaryGoMainRe = regexp.MustCompile(`^func\s+main\s*\(`) + summaryGoPackageMainRe = regexp.MustCompile(`^package\s+main\b`) +) + +func summaryHasGoMain(path string) bool { + f, err := os.Open(path) + if err != nil { + return false + } + defer func() { _ = f.Close() }() + + hasPackageMain := false + hasFuncMain := false + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if summaryGoPackageMainRe.MatchString(line) { + hasPackageMain = true + } + if summaryGoMainRe.MatchString(line) { + hasFuncMain = true + } + } + return hasPackageMain && hasFuncMain +} + +var summaryPyMainRe = regexp.MustCompile(`^if\s+__name__\s*==\s*['"]__main__['"]`) + +func summaryHasPythonMain(path string) bool { + f, err := os.Open(path) + if err != nil { + return false + } + defer func() { _ = f.Close() }() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + if summaryPyMainRe.MatchString(scanner.Text()) { + return true + } + } + return false +} + +func summaryFindJSEntryPoints(packageJSONPath string, projectDir string) []string { + data, err := os.ReadFile(packageJSONPath) + if err != nil { + return nil + } + + var pkg struct { + Main string `json:"main"` + } + if unmarshalErr := json.Unmarshal(data, &pkg); unmarshalErr != nil { + return nil + } + + if pkg.Main == "" { + return nil + } + + dir := filepath.Dir(packageJSONPath) + rel, err := filepath.Rel(projectDir, filepath.Join(dir, pkg.Main)) + if err != nil { + return nil + } + return []string{rel} +} + +func summaryCollectPackageSymbols(projectDir, pkgPath string) []string { + dir := filepath.Join(projectDir, pkgPath) + if pkgPath == summaryProjectRoot { + dir = projectDir + } + + var symbols []string + entries, err := os.ReadDir(dir) + if err != nil { + return nil + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + path := filepath.Join(dir, entry.Name()) + if summaryIsSupportedFile(path) { + symbols = append(symbols, summaryExtractSymbols(path)...) + } + } + return symbols +} + +func summaryIsConfigFile(path string) bool { + base := strings.ToLower(filepath.Base(path)) + configPatterns := []string{ + "config", "settings", "conf", ".env", "yaml", "yml", "toml", + "makefile", "dockerfile", "docker-compose", + } + for _, p := range configPatterns { + if strings.Contains(base, p) { + return true + } + } + return false +} + +func inferProjectDescription(name string, packages []SummaryPackageInfo, lang string) string { + if len(packages) == 0 { + return fmt.Sprintf("A %s project", lang) + } + + // Look for notable package names to infer purpose + hasAPI := false + hasCLI := false + hasWeb := false + hasEngine := false + + for _, pkg := range packages { + lower := strings.ToLower(pkg.Path) + if strings.Contains(lower, "api") || strings.Contains(lower, "handler") { + hasAPI = true + } + if strings.Contains(lower, "cmd") || strings.Contains(lower, "cli") { + hasCLI = true + } + if strings.Contains(lower, "web") || strings.Contains(lower, "frontend") { + hasWeb = true + } + if strings.Contains(lower, "engine") || strings.Contains(lower, "core") { + hasEngine = true + } + } + + switch { + case hasCLI && hasEngine: + return fmt.Sprintf("A %s CLI application with core engine", lang) + case hasCLI: + return fmt.Sprintf("A %s command-line application", lang) + case hasAPI && hasWeb: + return fmt.Sprintf("A %s full-stack web application", lang) + case hasAPI: + return fmt.Sprintf("A %s API service", lang) + case hasWeb: + return fmt.Sprintf("A %s web application", lang) + default: + return fmt.Sprintf("A %s project with %d packages", lang, len(packages)) + } +} + +func summaryDescribeArchitecture(summary *CodebaseSummary) string { + arch := summary.Architecture + + // Try to describe the layer flow for layered architectures + if arch == "layered" && len(summary.Packages) > 0 { + layers := make([]string, 0, 4) + layerNames := []string{"cmd", "engine", "service", "handler", "tool", "config", "store", "repo"} + for _, ln := range layerNames { + for _, pkg := range summary.Packages { + if strings.Contains(strings.ToLower(pkg.Path), ln) { + layers = append(layers, pkg.Path) + break + } + } + if len(layers) >= 4 { + break + } + } + if len(layers) >= 2 { + return fmt.Sprintf("layered (%s)", strings.Join(layers, " -> ")) + } + } + + return arch +} + +func summaryFormatNumber(n int) string { + if n >= 1000 { + return fmt.Sprintf("%d,%03d", n/1000, n%1000) + } + return fmt.Sprintf("%d", n) +} + +func summaryFormatLOC(loc int) string { + if loc >= 1000000 { + return fmt.Sprintf("%.1fM", float64(loc)/1000000.0) + } + if loc >= 1000 { + return fmt.Sprintf("%dK", loc/1000) + } + return fmt.Sprintf("%d", loc) +} + +func summaryFormatPackageName(name string) string { + if name == "" { + return "Unknown" + } + // Capitalize first letter + return strings.ToUpper(name[:1]) + name[1:] +} + +func summaryEstimateTokens(text string) int { + // Rough estimate: 1 token per 4 characters + return len(text) / 4 +} From 8c8ecc98890c9c4ef0e74ee6a44976987cda1309 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:23:50 +0530 Subject: [PATCH 13/20] refactor(engine/code): split code explainer helpers into code_explainer_helpers.go --- internal/engine/code/code_explainer.go | 426 ----------------- .../engine/code/code_explainer_helpers.go | 440 ++++++++++++++++++ 2 files changed, 440 insertions(+), 426 deletions(-) create mode 100644 internal/engine/code/code_explainer_helpers.go diff --git a/internal/engine/code/code_explainer.go b/internal/engine/code/code_explainer.go index fba30b97..531c552d 100644 --- a/internal/engine/code/code_explainer.go +++ b/internal/engine/code/code_explainer.go @@ -608,429 +608,3 @@ func FormatExplanation(exp *CodeExplanation) string { return sb.String() } - -func explainerExtractParams(fd *ast.FuncDecl) [][2]string { - var params [][2]string - if fd.Type.Params == nil { - return params - } - for _, field := range fd.Type.Params.List { - typeStr := explainerExprToString(field.Type) - if len(field.Names) == 0 { - params = append(params, [2]string{"", typeStr}) - } - for _, name := range field.Names { - params = append(params, [2]string{name.Name, typeStr}) - } - } - return params -} - -func extractReturns(fd *ast.FuncDecl) []string { - var returns []string - if fd.Type.Results == nil { - return returns - } - for _, field := range fd.Type.Results.List { - typeStr := explainerExprToString(field.Type) - if len(field.Names) > 0 { - for range field.Names { - returns = append(returns, typeStr) - } - } else { - returns = append(returns, typeStr) - } - } - return returns -} - -func explainerExtractDocComment(fd *ast.FuncDecl) string { - if fd.Doc == nil { - return "" - } - text := strings.TrimSpace(fd.Doc.Text()) - if strings.HasPrefix(text, fd.Name.Name+" ") { - text = text[len(fd.Name.Name)+1:] - } - if idx := strings.Index(text, ". "); idx > 0 { - return text[:idx+1] - } - return text -} - -func extractFuncBody(content string, fd *ast.FuncDecl, fset *token.FileSet) string { - if fd.Body == nil { - return "" - } - start := fset.Position(fd.Body.Lbrace).Offset - end := fset.Position(fd.Body.Rbrace).Offset - if start >= 0 && end > start && end <= len(content) { - return content[start:end] - } - return "" -} - -func explainerExprToString(expr ast.Expr) string { - switch e := expr.(type) { - case *ast.Ident: - return e.Name - case *ast.SelectorExpr: - return explainerExprToString(e.X) + "." + e.Sel.Name - case *ast.StarExpr: - return "*" + explainerExprToString(e.X) - case *ast.ArrayType: - return "[]" + explainerExprToString(e.Elt) - case *ast.MapType: - return "map[" + explainerExprToString(e.Key) + "]" + explainerExprToString(e.Value) - case *ast.InterfaceType: - return "interface{}" - case *ast.FuncType: - return "func(...)" - case *ast.ChanType: - return "chan " + explainerExprToString(e.Value) - case *ast.Ellipsis: - return "..." + explainerExprToString(e.Elt) - default: - return "unknown" - } -} - -func inferParamPurpose(name, typeName string) string { - lower := strings.ToLower(name) - switch { - case lower == "ctx" || typeName == "context.Context": - return "Context for cancellation and deadlines" - case lower == "err" || typeName == "error": - return "Error to handle" - case strings.Contains(lower, "path") || strings.Contains(lower, "file"): - return "File path" - case strings.Contains(lower, "name"): - return "Name identifier" - case strings.Contains(lower, "id"): - return "Unique identifier" - case strings.Contains(lower, "timeout") || strings.Contains(lower, "duration"): - return "Time duration" - case strings.Contains(lower, "config") || strings.Contains(lower, "opts"): - return "Configuration options" - case strings.Contains(lower, "fn") || strings.Contains(lower, "func") || strings.Contains(lower, "callback"): - return "Callback function" - case strings.Contains(lower, "buf") || strings.Contains(lower, "data") || strings.Contains(lower, "bytes"): - return "Raw data buffer" - case strings.Contains(lower, "url") || strings.Contains(lower, "addr"): - return "Network address" - case strings.Contains(lower, "token"): - return "Authentication or parsing token" - case strings.Contains(lower, "key"): - return "Lookup key" - case strings.Contains(lower, "val") || strings.Contains(lower, "value"): - return "Value to process" - case typeName == "string": - return fmt.Sprintf("The %s string", name) - case typeName == "int" || typeName == "int64": - return fmt.Sprintf("The %s count or index", name) - case typeName == "bool": - return fmt.Sprintf("Whether to enable %s", name) - default: - return fmt.Sprintf("The %s to use", name) - } -} - -func inferTypePurpose(name string) string { - words := splitCamelCase(name) - if len(words) == 0 { - return "A type" - } - last := strings.ToLower(words[len(words)-1]) - prefix := "" - if len(words) > 1 { - prefix = strings.Join(words[:len(words)-1], " ") - } - - switch last { - case "config", "options", "opts", "settings": - return fmt.Sprintf("Configuration for %s", lowerFirst(prefix)) - case "handler": - return fmt.Sprintf("Handles %s operations", lowerFirst(prefix)) - case "service", "server": - return fmt.Sprintf("Provides %s functionality", lowerFirst(prefix)) - case "client": - return fmt.Sprintf("Client for communicating with %s", lowerFirst(prefix)) - case "store", "repository", "repo": - return fmt.Sprintf("Persistent storage for %s", lowerFirst(prefix)) - case "manager": - return fmt.Sprintf("Manages lifecycle of %s", lowerFirst(prefix)) - case "builder": - return fmt.Sprintf("Builder pattern for constructing %s", lowerFirst(prefix)) - case "error", "err": - return fmt.Sprintf("Error type for %s failures", lowerFirst(prefix)) - case "result": - return fmt.Sprintf("Result of %s operation", lowerFirst(prefix)) - case "request", "req": - return fmt.Sprintf("Request payload for %s", lowerFirst(prefix)) - case "response", "resp": - return fmt.Sprintf("Response from %s", lowerFirst(prefix)) - default: - return fmt.Sprintf("Represents %s", lowerFirst(strings.Join(words, " "))) - } -} - -func inferFieldPurpose(name, typeName string) string { - lower := strings.ToLower(name) - switch { - case lower == "mu" || strings.Contains(typeName, "Mutex"): - return "Protects concurrent access" - case lower == "ctx": - return "Context for cancellation" - case lower == "id": - return "Unique identifier" - case strings.Contains(lower, "err"): - return "Last error encountered" - case strings.Contains(lower, "done") || strings.Contains(lower, "closed"): - return "Signals completion or shutdown" - case strings.Contains(lower, "count") || strings.Contains(lower, "num"): - return "Counter value" - case strings.Contains(lower, "max") || strings.Contains(lower, "limit"): - return "Upper bound constraint" - case strings.Contains(lower, "min"): - return "Lower bound constraint" - case strings.Contains(lower, "timeout"): - return "Maximum wait duration" - case strings.Contains(lower, "name"): - return "Human-readable name" - case strings.Contains(lower, "path") || strings.Contains(lower, "dir"): - return "Filesystem path" - default: - return fmt.Sprintf("The %s value", name) - } -} - -func splitCamelCase(s string) []string { - var words []string - current := strings.Builder{} - for i, r := range s { - if i > 0 && r >= 'A' && r <= 'Z' { - if current.Len() > 0 { - words = append(words, current.String()) - current.Reset() - } - } - current.WriteRune(r) - } - if current.Len() > 0 { - words = append(words, current.String()) - } - return words -} - -func lowerFirst(s string) string { - if s == "" { - return s - } - return strings.ToLower(s[:1]) + s[1:] -} - -func containsType(types []string, target string) bool { - for _, t := range types { - if t == target || strings.Contains(t, target) { - return true - } - } - return false -} - -func computeCyclomaticComplexity(body string) int { - cc := 1 - patterns := []string{ - `\bif\b`, `\belse if\b`, `\bfor\b`, `\bcase\b`, - `&&`, `\|\|`, `\bselect\b`, - } - for _, p := range patterns { - re := regexp.MustCompile(p) - cc += len(re.FindAllString(body, -1)) - } - return cc -} - -func classifyComplexity(cc int) string { - switch { - case cc <= 5: - return "Low" - case cc <= 10: - return "Moderate" - case cc <= 20: - return "High" - default: - return "Very High" - } -} - -func describeErrorHandling(body string) string { - hasErrCheck := regexp.MustCompile(`if\s+err\s*!=\s*nil`).MatchString(body) - hasWrap := regexp.MustCompile(`fmt\.Errorf\(.*%w`).MatchString(body) - hasPanic := regexp.MustCompile(`\bpanic\(`).MatchString(body) - hasRecover := regexp.MustCompile(`\brecover\(\)`).MatchString(body) - - if !hasErrCheck && !hasPanic { - return "No explicit error handling" - } - - var parts []string - if hasErrCheck && hasWrap { - parts = append(parts, "Returns wrapped errors with context") - } else if hasErrCheck { - parts = append(parts, "Checks and propagates errors") - } - if hasPanic { - parts = append(parts, "May panic on unrecoverable errors") - } - if hasRecover { - parts = append(parts, "Includes panic recovery") - } - - if len(parts) == 0 { - return "Basic error checking" - } - return strings.Join(parts, "; ") -} - -func extractDependencies(body string) []string { - var deps []string - re := regexp.MustCompile(`\b([a-z][a-z0-9]+)\.\w+`) - matches := re.FindAllStringSubmatch(body, -1) - seen := map[string]bool{} - for _, m := range matches { - pkg := m[1] - if !seen[pkg] && pkg != "err" && pkg != "nil" { - seen[pkg] = true - deps = append(deps, pkg) - } - } - return deps -} - -func funcSignature(fd *ast.FuncDecl) string { - var sb strings.Builder - sb.WriteString(fd.Name.Name) - sb.WriteString("(") - if fd.Type.Params != nil { - var params []string - for _, field := range fd.Type.Params.List { - typeStr := explainerExprToString(field.Type) - if len(field.Names) > 0 { - for _, name := range field.Names { - params = append(params, name.Name+" "+typeStr) - } - } else { - params = append(params, typeStr) - } - } - sb.WriteString(strings.Join(params, ", ")) - } - sb.WriteString(")") - if fd.Type.Results != nil && len(fd.Type.Results.List) > 0 { - var results []string - for _, field := range fd.Type.Results.List { - results = append(results, explainerExprToString(field.Type)) - } - if len(results) == 1 { - sb.WriteString(" " + results[0]) - } else { - sb.WriteString(" (" + strings.Join(results, ", ") + ")") - } - } - return sb.String() -} - -func findConstructor(f *ast.File, typeName string) string { - candidates := []string{ - "New" + typeName, - "new" + typeName, - } - for _, decl := range f.Decls { - fd, ok := decl.(*ast.FuncDecl) - if !ok || fd.Recv != nil { - continue - } - for _, c := range candidates { - if fd.Name.Name == c { - return c - } - } - } - return "" -} - -func detectImplementedInterfaces(f *ast.File, typeName string, methods []string) []string { - var ifaces []string - methodSet := map[string]bool{} - for _, m := range methods { - methodSet[m] = true - } - - if methodSet["String"] { - ifaces = append(ifaces, "fmt.Stringer") - } - if methodSet["Error"] { - ifaces = append(ifaces, "error") - } - if methodSet["Read"] { - ifaces = append(ifaces, "io.Reader") - } - if methodSet["Write"] { - ifaces = append(ifaces, "io.Writer") - } - if methodSet["Close"] { - ifaces = append(ifaces, "io.Closer") - } - if methodSet["ServeHTTP"] { - ifaces = append(ifaces, "http.Handler") - } - if methodSet["MarshalJSON"] { - ifaces = append(ifaces, "json.Marshaler") - } - if methodSet["UnmarshalJSON"] { - ifaces = append(ifaces, "json.Unmarshaler") - } - if methodSet["Len"] && methodSet["Less"] && methodSet["Swap"] { - ifaces = append(ifaces, "sort.Interface") - } - - return ifaces -} - -func detectPatterns(content string) []string { - var patterns []string - - if regexp.MustCompile(`sync\.(Mutex|RWMutex)`).MatchString(content) { - patterns = append(patterns, "- Mutex-based concurrency control") - } - if regexp.MustCompile(`sync\.Once`).MatchString(content) { - patterns = append(patterns, "- Singleton/once initialization") - } - if regexp.MustCompile(`chan\s+\w`).MatchString(content) { - patterns = append(patterns, "- Channel-based communication") - } - if regexp.MustCompile(`context\.Context`).MatchString(content) { - patterns = append(patterns, "- Context propagation for cancellation") - } - if regexp.MustCompile(`interface\s*\{`).MatchString(content) { - patterns = append(patterns, "- Interface-based abstraction") - } - if regexp.MustCompile(`func\s+New\w+\(`).MatchString(content) { - patterns = append(patterns, "- Constructor functions (New* pattern)") - } - if regexp.MustCompile(`defer\s+`).MatchString(content) { - patterns = append(patterns, "- Deferred cleanup") - } - if regexp.MustCompile(`fmt\.Errorf\(.*%w`).MatchString(content) { - patterns = append(patterns, "- Error wrapping with context") - } - if regexp.MustCompile(`select\s*\{`).MatchString(content) { - patterns = append(patterns, "- Select-based multiplexing") - } - if regexp.MustCompile(`type\s+\w+\s+struct\s*\{[^}]*\w+\s+interface`).MatchString(content) { - patterns = append(patterns, "- Dependency injection via interfaces") - } - - return patterns -} diff --git a/internal/engine/code/code_explainer_helpers.go b/internal/engine/code/code_explainer_helpers.go new file mode 100644 index 00000000..7c92fdf6 --- /dev/null +++ b/internal/engine/code/code_explainer_helpers.go @@ -0,0 +1,440 @@ +package code + +import ( + "fmt" + "go/ast" + "go/token" + "regexp" + "strings" +) + +// This file holds the internal helpers for the CodeExplainer: AST extraction, +// purpose inference, complexity/error-handling analysis, and pattern detection. +// The CodeExplainer type and its public Explain*/Infer/Describe/Detect/Format +// entry points live in code_explainer.go. + +func explainerExtractParams(fd *ast.FuncDecl) [][2]string { + var params [][2]string + if fd.Type.Params == nil { + return params + } + for _, field := range fd.Type.Params.List { + typeStr := explainerExprToString(field.Type) + if len(field.Names) == 0 { + params = append(params, [2]string{"", typeStr}) + } + for _, name := range field.Names { + params = append(params, [2]string{name.Name, typeStr}) + } + } + return params +} + +func extractReturns(fd *ast.FuncDecl) []string { + var returns []string + if fd.Type.Results == nil { + return returns + } + for _, field := range fd.Type.Results.List { + typeStr := explainerExprToString(field.Type) + if len(field.Names) > 0 { + for range field.Names { + returns = append(returns, typeStr) + } + } else { + returns = append(returns, typeStr) + } + } + return returns +} + +func explainerExtractDocComment(fd *ast.FuncDecl) string { + if fd.Doc == nil { + return "" + } + text := strings.TrimSpace(fd.Doc.Text()) + if strings.HasPrefix(text, fd.Name.Name+" ") { + text = text[len(fd.Name.Name)+1:] + } + if idx := strings.Index(text, ". "); idx > 0 { + return text[:idx+1] + } + return text +} + +func extractFuncBody(content string, fd *ast.FuncDecl, fset *token.FileSet) string { + if fd.Body == nil { + return "" + } + start := fset.Position(fd.Body.Lbrace).Offset + end := fset.Position(fd.Body.Rbrace).Offset + if start >= 0 && end > start && end <= len(content) { + return content[start:end] + } + return "" +} + +func explainerExprToString(expr ast.Expr) string { + switch e := expr.(type) { + case *ast.Ident: + return e.Name + case *ast.SelectorExpr: + return explainerExprToString(e.X) + "." + e.Sel.Name + case *ast.StarExpr: + return "*" + explainerExprToString(e.X) + case *ast.ArrayType: + return "[]" + explainerExprToString(e.Elt) + case *ast.MapType: + return "map[" + explainerExprToString(e.Key) + "]" + explainerExprToString(e.Value) + case *ast.InterfaceType: + return "interface{}" + case *ast.FuncType: + return "func(...)" + case *ast.ChanType: + return "chan " + explainerExprToString(e.Value) + case *ast.Ellipsis: + return "..." + explainerExprToString(e.Elt) + default: + return "unknown" + } +} + +func inferParamPurpose(name, typeName string) string { + lower := strings.ToLower(name) + switch { + case lower == "ctx" || typeName == "context.Context": + return "Context for cancellation and deadlines" + case lower == "err" || typeName == "error": + return "Error to handle" + case strings.Contains(lower, "path") || strings.Contains(lower, "file"): + return "File path" + case strings.Contains(lower, "name"): + return "Name identifier" + case strings.Contains(lower, "id"): + return "Unique identifier" + case strings.Contains(lower, "timeout") || strings.Contains(lower, "duration"): + return "Time duration" + case strings.Contains(lower, "config") || strings.Contains(lower, "opts"): + return "Configuration options" + case strings.Contains(lower, "fn") || strings.Contains(lower, "func") || strings.Contains(lower, "callback"): + return "Callback function" + case strings.Contains(lower, "buf") || strings.Contains(lower, "data") || strings.Contains(lower, "bytes"): + return "Raw data buffer" + case strings.Contains(lower, "url") || strings.Contains(lower, "addr"): + return "Network address" + case strings.Contains(lower, "token"): + return "Authentication or parsing token" + case strings.Contains(lower, "key"): + return "Lookup key" + case strings.Contains(lower, "val") || strings.Contains(lower, "value"): + return "Value to process" + case typeName == "string": + return fmt.Sprintf("The %s string", name) + case typeName == "int" || typeName == "int64": + return fmt.Sprintf("The %s count or index", name) + case typeName == "bool": + return fmt.Sprintf("Whether to enable %s", name) + default: + return fmt.Sprintf("The %s to use", name) + } +} + +func inferTypePurpose(name string) string { + words := splitCamelCase(name) + if len(words) == 0 { + return "A type" + } + last := strings.ToLower(words[len(words)-1]) + prefix := "" + if len(words) > 1 { + prefix = strings.Join(words[:len(words)-1], " ") + } + + switch last { + case "config", "options", "opts", "settings": + return fmt.Sprintf("Configuration for %s", lowerFirst(prefix)) + case "handler": + return fmt.Sprintf("Handles %s operations", lowerFirst(prefix)) + case "service", "server": + return fmt.Sprintf("Provides %s functionality", lowerFirst(prefix)) + case "client": + return fmt.Sprintf("Client for communicating with %s", lowerFirst(prefix)) + case "store", "repository", "repo": + return fmt.Sprintf("Persistent storage for %s", lowerFirst(prefix)) + case "manager": + return fmt.Sprintf("Manages lifecycle of %s", lowerFirst(prefix)) + case "builder": + return fmt.Sprintf("Builder pattern for constructing %s", lowerFirst(prefix)) + case "error", "err": + return fmt.Sprintf("Error type for %s failures", lowerFirst(prefix)) + case "result": + return fmt.Sprintf("Result of %s operation", lowerFirst(prefix)) + case "request", "req": + return fmt.Sprintf("Request payload for %s", lowerFirst(prefix)) + case "response", "resp": + return fmt.Sprintf("Response from %s", lowerFirst(prefix)) + default: + return fmt.Sprintf("Represents %s", lowerFirst(strings.Join(words, " "))) + } +} + +func inferFieldPurpose(name, typeName string) string { + lower := strings.ToLower(name) + switch { + case lower == "mu" || strings.Contains(typeName, "Mutex"): + return "Protects concurrent access" + case lower == "ctx": + return "Context for cancellation" + case lower == "id": + return "Unique identifier" + case strings.Contains(lower, "err"): + return "Last error encountered" + case strings.Contains(lower, "done") || strings.Contains(lower, "closed"): + return "Signals completion or shutdown" + case strings.Contains(lower, "count") || strings.Contains(lower, "num"): + return "Counter value" + case strings.Contains(lower, "max") || strings.Contains(lower, "limit"): + return "Upper bound constraint" + case strings.Contains(lower, "min"): + return "Lower bound constraint" + case strings.Contains(lower, "timeout"): + return "Maximum wait duration" + case strings.Contains(lower, "name"): + return "Human-readable name" + case strings.Contains(lower, "path") || strings.Contains(lower, "dir"): + return "Filesystem path" + default: + return fmt.Sprintf("The %s value", name) + } +} + +func splitCamelCase(s string) []string { + var words []string + current := strings.Builder{} + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + if current.Len() > 0 { + words = append(words, current.String()) + current.Reset() + } + } + current.WriteRune(r) + } + if current.Len() > 0 { + words = append(words, current.String()) + } + return words +} + +func lowerFirst(s string) string { + if s == "" { + return s + } + return strings.ToLower(s[:1]) + s[1:] +} + +func containsType(types []string, target string) bool { + for _, t := range types { + if t == target || strings.Contains(t, target) { + return true + } + } + return false +} + +func computeCyclomaticComplexity(body string) int { + cc := 1 + patterns := []string{ + `\bif\b`, `\belse if\b`, `\bfor\b`, `\bcase\b`, + `&&`, `\|\|`, `\bselect\b`, + } + for _, p := range patterns { + re := regexp.MustCompile(p) + cc += len(re.FindAllString(body, -1)) + } + return cc +} + +func classifyComplexity(cc int) string { + switch { + case cc <= 5: + return "Low" + case cc <= 10: + return "Moderate" + case cc <= 20: + return "High" + default: + return "Very High" + } +} + +func describeErrorHandling(body string) string { + hasErrCheck := regexp.MustCompile(`if\s+err\s*!=\s*nil`).MatchString(body) + hasWrap := regexp.MustCompile(`fmt\.Errorf\(.*%w`).MatchString(body) + hasPanic := regexp.MustCompile(`\bpanic\(`).MatchString(body) + hasRecover := regexp.MustCompile(`\brecover\(\)`).MatchString(body) + + if !hasErrCheck && !hasPanic { + return "No explicit error handling" + } + + var parts []string + if hasErrCheck && hasWrap { + parts = append(parts, "Returns wrapped errors with context") + } else if hasErrCheck { + parts = append(parts, "Checks and propagates errors") + } + if hasPanic { + parts = append(parts, "May panic on unrecoverable errors") + } + if hasRecover { + parts = append(parts, "Includes panic recovery") + } + + if len(parts) == 0 { + return "Basic error checking" + } + return strings.Join(parts, "; ") +} + +func extractDependencies(body string) []string { + var deps []string + re := regexp.MustCompile(`\b([a-z][a-z0-9]+)\.\w+`) + matches := re.FindAllStringSubmatch(body, -1) + seen := map[string]bool{} + for _, m := range matches { + pkg := m[1] + if !seen[pkg] && pkg != "err" && pkg != "nil" { + seen[pkg] = true + deps = append(deps, pkg) + } + } + return deps +} + +func funcSignature(fd *ast.FuncDecl) string { + var sb strings.Builder + sb.WriteString(fd.Name.Name) + sb.WriteString("(") + if fd.Type.Params != nil { + var params []string + for _, field := range fd.Type.Params.List { + typeStr := explainerExprToString(field.Type) + if len(field.Names) > 0 { + for _, name := range field.Names { + params = append(params, name.Name+" "+typeStr) + } + } else { + params = append(params, typeStr) + } + } + sb.WriteString(strings.Join(params, ", ")) + } + sb.WriteString(")") + if fd.Type.Results != nil && len(fd.Type.Results.List) > 0 { + var results []string + for _, field := range fd.Type.Results.List { + results = append(results, explainerExprToString(field.Type)) + } + if len(results) == 1 { + sb.WriteString(" " + results[0]) + } else { + sb.WriteString(" (" + strings.Join(results, ", ") + ")") + } + } + return sb.String() +} + +func findConstructor(f *ast.File, typeName string) string { + candidates := []string{ + "New" + typeName, + "new" + typeName, + } + for _, decl := range f.Decls { + fd, ok := decl.(*ast.FuncDecl) + if !ok || fd.Recv != nil { + continue + } + for _, c := range candidates { + if fd.Name.Name == c { + return c + } + } + } + return "" +} + +func detectImplementedInterfaces(f *ast.File, typeName string, methods []string) []string { + var ifaces []string + methodSet := map[string]bool{} + for _, m := range methods { + methodSet[m] = true + } + + if methodSet["String"] { + ifaces = append(ifaces, "fmt.Stringer") + } + if methodSet["Error"] { + ifaces = append(ifaces, "error") + } + if methodSet["Read"] { + ifaces = append(ifaces, "io.Reader") + } + if methodSet["Write"] { + ifaces = append(ifaces, "io.Writer") + } + if methodSet["Close"] { + ifaces = append(ifaces, "io.Closer") + } + if methodSet["ServeHTTP"] { + ifaces = append(ifaces, "http.Handler") + } + if methodSet["MarshalJSON"] { + ifaces = append(ifaces, "json.Marshaler") + } + if methodSet["UnmarshalJSON"] { + ifaces = append(ifaces, "json.Unmarshaler") + } + if methodSet["Len"] && methodSet["Less"] && methodSet["Swap"] { + ifaces = append(ifaces, "sort.Interface") + } + + return ifaces +} + +func detectPatterns(content string) []string { + var patterns []string + + if regexp.MustCompile(`sync\.(Mutex|RWMutex)`).MatchString(content) { + patterns = append(patterns, "- Mutex-based concurrency control") + } + if regexp.MustCompile(`sync\.Once`).MatchString(content) { + patterns = append(patterns, "- Singleton/once initialization") + } + if regexp.MustCompile(`chan\s+\w`).MatchString(content) { + patterns = append(patterns, "- Channel-based communication") + } + if regexp.MustCompile(`context\.Context`).MatchString(content) { + patterns = append(patterns, "- Context propagation for cancellation") + } + if regexp.MustCompile(`interface\s*\{`).MatchString(content) { + patterns = append(patterns, "- Interface-based abstraction") + } + if regexp.MustCompile(`func\s+New\w+\(`).MatchString(content) { + patterns = append(patterns, "- Constructor functions (New* pattern)") + } + if regexp.MustCompile(`defer\s+`).MatchString(content) { + patterns = append(patterns, "- Deferred cleanup") + } + if regexp.MustCompile(`fmt\.Errorf\(.*%w`).MatchString(content) { + patterns = append(patterns, "- Error wrapping with context") + } + if regexp.MustCompile(`select\s*\{`).MatchString(content) { + patterns = append(patterns, "- Select-based multiplexing") + } + if regexp.MustCompile(`type\s+\w+\s+struct\s*\{[^}]*\w+\s+interface`).MatchString(content) { + patterns = append(patterns, "- Dependency injection via interfaces") + } + + return patterns +} From 920d23121a7a1cbf7349272d05bc3586b07fa902 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:26:25 +0530 Subject: [PATCH 14/20] refactor(repomap): split health score dimension scorers into health_score_dimensions.go --- internal/intelligence/repomap/health_score.go | 601 +---------------- .../repomap/health_score_dimensions.go | 613 ++++++++++++++++++ 2 files changed, 616 insertions(+), 598 deletions(-) create mode 100644 internal/intelligence/repomap/health_score_dimensions.go diff --git a/internal/intelligence/repomap/health_score.go b/internal/intelligence/repomap/health_score.go index 1e2d35df..565df4d8 100644 --- a/internal/intelligence/repomap/health_score.go +++ b/internal/intelligence/repomap/health_score.go @@ -8,9 +8,6 @@ package repomap import ( "fmt" "go/ast" - "go/parser" - "go/token" - "io/fs" "os" "path/filepath" "sort" @@ -131,601 +128,9 @@ func (hs *HealthScorer) Score(projectDir string) (*HealthScore, error) { return result, nil } -// ScoreTestCoverage evaluates testing practices and coverage. -func (hs *HealthScorer) ScoreTestCoverage(dir string) (float64, []HealthIssue) { - var issues []HealthIssue - sourceFiles := 0 - testFiles := 0 - dirsWithSource := make(map[string]bool) - dirsWithTests := make(map[string]bool) - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - name := d.Name() - if name == ".git" || name == "vendor" || name == "node_modules" { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") { - return nil - } - relDir := filepath.Dir(path) - if strings.HasSuffix(path, "_test.go") { - testFiles++ - dirsWithTests[relDir] = true - } else { - sourceFiles++ - dirsWithSource[relDir] = true - } - return nil - }) - - if sourceFiles == 0 { - return 100.0, issues - } - - // Calculate test-to-source ratio - ratio := float64(testFiles) / float64(sourceFiles) - ratioScore := ratio * 100.0 - if ratioScore > 100.0 { - ratioScore = 100.0 - } - - // Check directories without tests - for d := range dirsWithSource { - if !dirsWithTests[d] { - rel, _ := filepath.Rel(dir, d) - if rel == "" { - rel = d - } - issues = append(issues, HealthIssue{ - Dimension: "test_coverage", - Description: fmt.Sprintf("%s has no tests", rel), - Severity: "warning", - File: d, - Suggestion: fmt.Sprintf("Add test files to %s", rel), - }) - } - } - - // Penalize for directories without tests - if len(dirsWithSource) > 0 { - coverageRatio := float64(len(dirsWithTests)) / float64(len(dirsWithSource)) - score := (ratioScore*0.6 + coverageRatio*100.0*0.4) - if score > 100.0 { - score = 100.0 - } - return score, issues - } - - return ratioScore, issues -} - -// ScoreDocumentation evaluates documentation quality. -func (hs *HealthScorer) ScoreDocumentation(dir string) (float64, []HealthIssue) { - var issues []HealthIssue - score := 0.0 - checks := 0.0 - - // Check for README - readmeExists := false - readmeNames := []string{"README.md", "README", "README.txt", "readme.md"} - for _, name := range readmeNames { - if _, err := os.Stat(filepath.Join(dir, name)); err == nil { - readmeExists = true - break - } - } - checks++ - if readmeExists { - score += 100.0 - } else { - issues = append(issues, HealthIssue{ - Dimension: "documentation", - Description: "No README file found", - Severity: "warning", - File: dir, - Suggestion: "Add a README.md with project overview and usage instructions", - }) - } - - // Analyze exported function documentation - totalExported := 0 - documentedExported := 0 - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - name := d.Name() - if name == ".git" || name == "vendor" || name == "node_modules" || name == "testdata" { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - fset := token.NewFileSet() - f, parseErr := parser.ParseFile(fset, path, nil, parser.ParseComments) - if parseErr != nil { - return nil - } - - for _, decl := range f.Decls { - fn, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - if !fn.Name.IsExported() { - continue - } - totalExported++ - if fn.Doc != nil && len(fn.Doc.List) > 0 { - documentedExported++ - } - } - return nil - }) - - if totalExported > 0 { - checks++ - docRatio := float64(documentedExported) / float64(totalExported) * 100.0 - score += docRatio - - if docRatio < 50.0 { - issues = append(issues, HealthIssue{ - Dimension: "documentation", - Description: fmt.Sprintf("Only %.0f%% of exported functions are documented", docRatio), - Severity: "warning", - File: dir, - Suggestion: "Add doc comments to exported functions following Go conventions", - }) - } - } - - if checks == 0 { - return 100.0, issues - } - return score / checks, issues -} - -// ScoreComplexity evaluates code complexity across the project. -func (hs *HealthScorer) ScoreComplexity(dir string) (float64, []HealthIssue) { - var issues []HealthIssue - var complexities []int - highComplexityCount := 0 - threshold := 10 - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - name := d.Name() - if name == ".git" || name == "vendor" || name == "node_modules" || name == "testdata" { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - fset := token.NewFileSet() - f, parseErr := parser.ParseFile(fset, path, nil, parser.ParseComments) - if parseErr != nil { - return nil - } - - for _, decl := range f.Decls { - fn, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - cc := calculateCyclomaticComplexity(fn) - complexities = append(complexities, cc) - if cc > threshold { - highComplexityCount++ - rel, _ := filepath.Rel(dir, path) - if rel == "" { - rel = path - } - issues = append(issues, HealthIssue{ - Dimension: "complexity", - Description: fmt.Sprintf("Function %s has cyclomatic complexity %d (threshold: %d)", fn.Name.Name, cc, threshold), - Severity: severityForComplexity(cc), - File: rel, - Suggestion: fmt.Sprintf("Refactor %s to reduce branching", fn.Name.Name), - }) - } - } - return nil - }) - - if len(complexities) == 0 { - return 100.0, issues - } - - // Calculate average complexity - total := 0 - for _, c := range complexities { - total += c - } - avg := float64(total) / float64(len(complexities)) - - // Score based on average and high-complexity function ratio - avgScore := 100.0 - (avg-1.0)*10.0 - if avgScore < 0 { - avgScore = 0 - } - if avgScore > 100.0 { - avgScore = 100.0 - } - - highRatio := float64(highComplexityCount) / float64(len(complexities)) - ratioScore := (1.0 - highRatio) * 100.0 - - score := avgScore*0.6 + ratioScore*0.4 - if score > 100.0 { - score = 100.0 - } - if score < 0 { - score = 0 - } - - return score, issues -} - -// ScoreDependencies evaluates dependency health. -func (hs *HealthScorer) ScoreDependencies(dir string) (float64, []HealthIssue) { - var issues []HealthIssue - score := 100.0 - - // Check for go.mod - goModPath := filepath.Join(dir, "go.mod") - data, err := os.ReadFile(goModPath) - if err != nil { - // No go.mod — might be a simple project or not Go - return 80.0, issues - } - - lines := strings.Split(string(data), "\n") - depCount := 0 - inRequire := false - - for _, line := range lines { - trimmed := strings.TrimSpace(line) - if trimmed == "require (" { - inRequire = true - continue - } - if inRequire && trimmed == ")" { - inRequire = false - continue - } - if inRequire && trimmed != "" && !strings.HasPrefix(trimmed, "//") { - depCount++ - } - // Single-line require - if strings.HasPrefix(trimmed, "require ") && !strings.Contains(trimmed, "(") { - depCount++ - } - } - - // Penalize for excessive dependencies - if depCount > 50 { - penalty := float64(depCount-50) * 0.5 - if penalty > 30 { - penalty = 30 - } - score -= penalty - issues = append(issues, HealthIssue{ - Dimension: "dependencies", - Description: fmt.Sprintf("High dependency count: %d direct dependencies", depCount), - Severity: "warning", - File: "go.mod", - Suggestion: "Review dependencies for unused or replaceable modules", - }) - } - - // Check for replace directives (might indicate instability) - replaceCount := 0 - for _, line := range lines { - trimmed := strings.TrimSpace(line) - if strings.HasPrefix(trimmed, "replace ") { - replaceCount++ - } - } - if replaceCount > 3 { - score -= float64(replaceCount) * 2 - issues = append(issues, HealthIssue{ - Dimension: "dependencies", - Description: fmt.Sprintf("%d replace directives found", replaceCount), - Severity: "info", - File: "go.mod", - Suggestion: "Replace directives may indicate unstable dependencies", - }) - } - - if score < 0 { - score = 0 - } - return score, issues -} - -// ScoreCodeQuality evaluates overall code quality signals. -func (hs *HealthScorer) ScoreCodeQuality(dir string) (float64, []HealthIssue) { - var issues []HealthIssue - totalFiles := 0 - filesWithIssues := 0 - - var longFiles []string - var deadCodeFiles []string - - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - name := d.Name() - if name == ".git" || name == "vendor" || name == "node_modules" || name == "testdata" { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - totalFiles++ - hasIssue := false - - data, readErr := os.ReadFile(path) - if readErr != nil { - return nil - } - lines := strings.Split(string(data), "\n") - - // Check file length - if len(lines) > 500 { - hasIssue = true - rel, _ := filepath.Rel(dir, path) - if rel == "" { - rel = path - } - longFiles = append(longFiles, rel) - } - - // Check for potential dead code (commented-out functions) - commentedFuncCount := 0 - for _, line := range lines { - trimmed := strings.TrimSpace(line) - if strings.HasPrefix(trimmed, "// func ") || strings.HasPrefix(trimmed, "//func ") { - commentedFuncCount++ - } - } - if commentedFuncCount > 2 { - hasIssue = true - rel, _ := filepath.Rel(dir, path) - if rel == "" { - rel = path - } - deadCodeFiles = append(deadCodeFiles, rel) - } - - if hasIssue { - filesWithIssues++ - } - return nil - }) - - if len(longFiles) > 0 { - issues = append(issues, HealthIssue{ - Dimension: "code_quality", - Description: fmt.Sprintf("%d files exceed 500 lines", len(longFiles)), - Severity: "warning", - File: longFiles[0], - Suggestion: "Consider splitting large files into smaller, focused modules", - }) - } - - if len(deadCodeFiles) > 0 { - issues = append(issues, HealthIssue{ - Dimension: "code_quality", - Description: fmt.Sprintf("%d files contain commented-out code", len(deadCodeFiles)), - Severity: "info", - File: deadCodeFiles[0], - Suggestion: "Remove dead code; use version control for history", - }) - } - - if totalFiles == 0 { - return 100.0, issues - } - - qualityRatio := 1.0 - float64(filesWithIssues)/float64(totalFiles) - score := qualityRatio * 100.0 - if score < 0 { - score = 0 - } - return score, issues -} - -// ScoreMaintainability evaluates how maintainable the codebase is. -func (hs *HealthScorer) ScoreMaintainability(dir string) (float64, []HealthIssue) { - var issues []HealthIssue - score := 100.0 - - // Check package organization - pkgCount := 0 - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if !d.IsDir() { - return nil - } - name := d.Name() - if name == ".git" || name == "vendor" || name == "node_modules" { - return filepath.SkipDir - } - // Check if directory contains Go files - entries, _ := os.ReadDir(path) - for _, e := range entries { - if !e.IsDir() && strings.HasSuffix(e.Name(), ".go") { - pkgCount++ - break - } - } - return nil - }) - - // Check naming consistency - inconsistentNames := 0 - _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - name := d.Name() - if name == ".git" || name == "vendor" || name == "node_modules" { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - base := filepath.Base(path) - // Check for mixed naming conventions (camelCase vs snake_case) - if strings.Contains(base, "-") { - inconsistentNames++ - } - return nil - }) - - if inconsistentNames > 0 { - score -= float64(inconsistentNames) * 5 - issues = append(issues, HealthIssue{ - Dimension: "maintainability", - Description: fmt.Sprintf("%d files use non-standard naming", inconsistentNames), - Severity: "info", - File: dir, - Suggestion: "Use snake_case for Go file names", - }) - } - - // Check for consistent error handling patterns - errPatterns := checkErrorPatterns(dir) - if errPatterns < 0.7 { - score -= 15 - issues = append(issues, HealthIssue{ - Dimension: "maintainability", - Description: "Inconsistent error handling patterns detected", - Severity: "warning", - File: dir, - Suggestion: "Standardize error handling with wrapped errors using fmt.Errorf with %w", - }) - } - - if score < 0 { - score = 0 - } - if score > 100 { - score = 100 - } - return score, issues -} - -// ScoreSecurity evaluates basic security signals in the codebase. -func (hs *HealthScorer) ScoreSecurity(dir string) (float64, []HealthIssue) { - var issues []HealthIssue - score := 100.0 - - dangerousPatterns := []struct { - pattern string - description string - severity string - suggestion string - }{ - {"exec.Command", "Use of exec.Command may allow command injection", "warning", "Validate and sanitize all inputs to exec.Command"}, - {"os.Exec", "Direct OS exec calls detected", "warning", "Ensure executed commands are properly validated"}, - {"net/http", "HTTP usage without explicit TLS configuration", "info", "Consider enforcing HTTPS for external connections"}, - {"crypto/md5", "MD5 is cryptographically broken", "error", "Replace MD5 with SHA-256 or stronger"}, - {"crypto/sha1", "SHA-1 is deprecated for security purposes", "warning", "Replace SHA-1 with SHA-256 or stronger"}, - {"unsafe.Pointer", "Use of unsafe package bypasses type safety", "warning", "Avoid unsafe unless absolutely necessary"}, - } - - foundPatterns := make(map[string][]string) - - _ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return nil - } - if info.IsDir() { - name := info.Name() - if name == ".git" || name == "vendor" || name == "node_modules" || name == "testdata" { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { - return nil - } - - data, readErr := os.ReadFile(path) - if readErr != nil { - return nil - } - content := string(data) - rel, _ := filepath.Rel(dir, path) - if rel == "" { - rel = path - } - - for _, dp := range dangerousPatterns { - if strings.Contains(content, dp.pattern) { - foundPatterns[dp.pattern] = append(foundPatterns[dp.pattern], rel) - } - } - return nil - }) - - for _, dp := range dangerousPatterns { - files := foundPatterns[dp.pattern] - if len(files) > 0 { - var penalty float64 - switch dp.severity { - case "error": - penalty = 15 - case "warning": - penalty = 8 - case "info": - penalty = 3 - } - score -= penalty - issues = append(issues, HealthIssue{ - Dimension: "security", - Description: fmt.Sprintf("%s (found in %d files)", dp.description, len(files)), - Severity: dp.severity, - File: files[0], - Suggestion: dp.suggestion, - }) - } - } - - if score < 0 { - score = 0 - } - return score, issues -} +// The per-dimension scorer methods (ScoreTestCoverage, ScoreDocumentation, +// ScoreComplexity, ScoreDependencies, ScoreCodeQuality, ScoreMaintainability, +// ScoreSecurity) live in health_score_dimensions.go. // FormatScore produces a human-readable health report. func FormatScore(score *HealthScore) string { diff --git a/internal/intelligence/repomap/health_score_dimensions.go b/internal/intelligence/repomap/health_score_dimensions.go new file mode 100644 index 00000000..aef93b13 --- /dev/null +++ b/internal/intelligence/repomap/health_score_dimensions.go @@ -0,0 +1,613 @@ +package repomap + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "io/fs" + "os" + "path/filepath" + "strings" +) + +// This file holds the per-dimension HealthScorer methods (test coverage, +// documentation, complexity, dependencies, code quality, maintainability, +// security). The HealthScorer type, aggregation (Score), reporting +// (FormatScore/CompareScores), and shared helpers live in health_score.go. + +// ScoreTestCoverage evaluates testing practices and coverage. +func (hs *HealthScorer) ScoreTestCoverage(dir string) (float64, []HealthIssue) { + var issues []HealthIssue + sourceFiles := 0 + testFiles := 0 + dirsWithSource := make(map[string]bool) + dirsWithTests := make(map[string]bool) + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + name := d.Name() + if name == ".git" || name == "vendor" || name == "node_modules" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") { + return nil + } + relDir := filepath.Dir(path) + if strings.HasSuffix(path, "_test.go") { + testFiles++ + dirsWithTests[relDir] = true + } else { + sourceFiles++ + dirsWithSource[relDir] = true + } + return nil + }) + + if sourceFiles == 0 { + return 100.0, issues + } + + // Calculate test-to-source ratio + ratio := float64(testFiles) / float64(sourceFiles) + ratioScore := ratio * 100.0 + if ratioScore > 100.0 { + ratioScore = 100.0 + } + + // Check directories without tests + for d := range dirsWithSource { + if !dirsWithTests[d] { + rel, _ := filepath.Rel(dir, d) + if rel == "" { + rel = d + } + issues = append(issues, HealthIssue{ + Dimension: "test_coverage", + Description: fmt.Sprintf("%s has no tests", rel), + Severity: "warning", + File: d, + Suggestion: fmt.Sprintf("Add test files to %s", rel), + }) + } + } + + // Penalize for directories without tests + if len(dirsWithSource) > 0 { + coverageRatio := float64(len(dirsWithTests)) / float64(len(dirsWithSource)) + score := (ratioScore*0.6 + coverageRatio*100.0*0.4) + if score > 100.0 { + score = 100.0 + } + return score, issues + } + + return ratioScore, issues +} + +// ScoreDocumentation evaluates documentation quality. +func (hs *HealthScorer) ScoreDocumentation(dir string) (float64, []HealthIssue) { + var issues []HealthIssue + score := 0.0 + checks := 0.0 + + // Check for README + readmeExists := false + readmeNames := []string{"README.md", "README", "README.txt", "readme.md"} + for _, name := range readmeNames { + if _, err := os.Stat(filepath.Join(dir, name)); err == nil { + readmeExists = true + break + } + } + checks++ + if readmeExists { + score += 100.0 + } else { + issues = append(issues, HealthIssue{ + Dimension: "documentation", + Description: "No README file found", + Severity: "warning", + File: dir, + Suggestion: "Add a README.md with project overview and usage instructions", + }) + } + + // Analyze exported function documentation + totalExported := 0 + documentedExported := 0 + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + name := d.Name() + if name == ".git" || name == "vendor" || name == "node_modules" || name == "testdata" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + fset := token.NewFileSet() + f, parseErr := parser.ParseFile(fset, path, nil, parser.ParseComments) + if parseErr != nil { + return nil + } + + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if !fn.Name.IsExported() { + continue + } + totalExported++ + if fn.Doc != nil && len(fn.Doc.List) > 0 { + documentedExported++ + } + } + return nil + }) + + if totalExported > 0 { + checks++ + docRatio := float64(documentedExported) / float64(totalExported) * 100.0 + score += docRatio + + if docRatio < 50.0 { + issues = append(issues, HealthIssue{ + Dimension: "documentation", + Description: fmt.Sprintf("Only %.0f%% of exported functions are documented", docRatio), + Severity: "warning", + File: dir, + Suggestion: "Add doc comments to exported functions following Go conventions", + }) + } + } + + if checks == 0 { + return 100.0, issues + } + return score / checks, issues +} + +// ScoreComplexity evaluates code complexity across the project. +func (hs *HealthScorer) ScoreComplexity(dir string) (float64, []HealthIssue) { + var issues []HealthIssue + var complexities []int + highComplexityCount := 0 + threshold := 10 + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + name := d.Name() + if name == ".git" || name == "vendor" || name == "node_modules" || name == "testdata" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + fset := token.NewFileSet() + f, parseErr := parser.ParseFile(fset, path, nil, parser.ParseComments) + if parseErr != nil { + return nil + } + + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + cc := calculateCyclomaticComplexity(fn) + complexities = append(complexities, cc) + if cc > threshold { + highComplexityCount++ + rel, _ := filepath.Rel(dir, path) + if rel == "" { + rel = path + } + issues = append(issues, HealthIssue{ + Dimension: "complexity", + Description: fmt.Sprintf("Function %s has cyclomatic complexity %d (threshold: %d)", fn.Name.Name, cc, threshold), + Severity: severityForComplexity(cc), + File: rel, + Suggestion: fmt.Sprintf("Refactor %s to reduce branching", fn.Name.Name), + }) + } + } + return nil + }) + + if len(complexities) == 0 { + return 100.0, issues + } + + // Calculate average complexity + total := 0 + for _, c := range complexities { + total += c + } + avg := float64(total) / float64(len(complexities)) + + // Score based on average and high-complexity function ratio + avgScore := 100.0 - (avg-1.0)*10.0 + if avgScore < 0 { + avgScore = 0 + } + if avgScore > 100.0 { + avgScore = 100.0 + } + + highRatio := float64(highComplexityCount) / float64(len(complexities)) + ratioScore := (1.0 - highRatio) * 100.0 + + score := avgScore*0.6 + ratioScore*0.4 + if score > 100.0 { + score = 100.0 + } + if score < 0 { + score = 0 + } + + return score, issues +} + +// ScoreDependencies evaluates dependency health. +func (hs *HealthScorer) ScoreDependencies(dir string) (float64, []HealthIssue) { + var issues []HealthIssue + score := 100.0 + + // Check for go.mod + goModPath := filepath.Join(dir, "go.mod") + data, err := os.ReadFile(goModPath) + if err != nil { + // No go.mod — might be a simple project or not Go + return 80.0, issues + } + + lines := strings.Split(string(data), "\n") + depCount := 0 + inRequire := false + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "require (" { + inRequire = true + continue + } + if inRequire && trimmed == ")" { + inRequire = false + continue + } + if inRequire && trimmed != "" && !strings.HasPrefix(trimmed, "//") { + depCount++ + } + // Single-line require + if strings.HasPrefix(trimmed, "require ") && !strings.Contains(trimmed, "(") { + depCount++ + } + } + + // Penalize for excessive dependencies + if depCount > 50 { + penalty := float64(depCount-50) * 0.5 + if penalty > 30 { + penalty = 30 + } + score -= penalty + issues = append(issues, HealthIssue{ + Dimension: "dependencies", + Description: fmt.Sprintf("High dependency count: %d direct dependencies", depCount), + Severity: "warning", + File: "go.mod", + Suggestion: "Review dependencies for unused or replaceable modules", + }) + } + + // Check for replace directives (might indicate instability) + replaceCount := 0 + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "replace ") { + replaceCount++ + } + } + if replaceCount > 3 { + score -= float64(replaceCount) * 2 + issues = append(issues, HealthIssue{ + Dimension: "dependencies", + Description: fmt.Sprintf("%d replace directives found", replaceCount), + Severity: "info", + File: "go.mod", + Suggestion: "Replace directives may indicate unstable dependencies", + }) + } + + if score < 0 { + score = 0 + } + return score, issues +} + +// ScoreCodeQuality evaluates overall code quality signals. +func (hs *HealthScorer) ScoreCodeQuality(dir string) (float64, []HealthIssue) { + var issues []HealthIssue + totalFiles := 0 + filesWithIssues := 0 + + var longFiles []string + var deadCodeFiles []string + + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + name := d.Name() + if name == ".git" || name == "vendor" || name == "node_modules" || name == "testdata" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + totalFiles++ + hasIssue := false + + data, readErr := os.ReadFile(path) + if readErr != nil { + return nil + } + lines := strings.Split(string(data), "\n") + + // Check file length + if len(lines) > 500 { + hasIssue = true + rel, _ := filepath.Rel(dir, path) + if rel == "" { + rel = path + } + longFiles = append(longFiles, rel) + } + + // Check for potential dead code (commented-out functions) + commentedFuncCount := 0 + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "// func ") || strings.HasPrefix(trimmed, "//func ") { + commentedFuncCount++ + } + } + if commentedFuncCount > 2 { + hasIssue = true + rel, _ := filepath.Rel(dir, path) + if rel == "" { + rel = path + } + deadCodeFiles = append(deadCodeFiles, rel) + } + + if hasIssue { + filesWithIssues++ + } + return nil + }) + + if len(longFiles) > 0 { + issues = append(issues, HealthIssue{ + Dimension: "code_quality", + Description: fmt.Sprintf("%d files exceed 500 lines", len(longFiles)), + Severity: "warning", + File: longFiles[0], + Suggestion: "Consider splitting large files into smaller, focused modules", + }) + } + + if len(deadCodeFiles) > 0 { + issues = append(issues, HealthIssue{ + Dimension: "code_quality", + Description: fmt.Sprintf("%d files contain commented-out code", len(deadCodeFiles)), + Severity: "info", + File: deadCodeFiles[0], + Suggestion: "Remove dead code; use version control for history", + }) + } + + if totalFiles == 0 { + return 100.0, issues + } + + qualityRatio := 1.0 - float64(filesWithIssues)/float64(totalFiles) + score := qualityRatio * 100.0 + if score < 0 { + score = 0 + } + return score, issues +} + +// ScoreMaintainability evaluates how maintainable the codebase is. +func (hs *HealthScorer) ScoreMaintainability(dir string) (float64, []HealthIssue) { + var issues []HealthIssue + score := 100.0 + + // Check package organization + pkgCount := 0 + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() { + return nil + } + name := d.Name() + if name == ".git" || name == "vendor" || name == "node_modules" { + return filepath.SkipDir + } + // Check if directory contains Go files + entries, _ := os.ReadDir(path) + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(e.Name(), ".go") { + pkgCount++ + break + } + } + return nil + }) + + // Check naming consistency + inconsistentNames := 0 + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + name := d.Name() + if name == ".git" || name == "vendor" || name == "node_modules" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + base := filepath.Base(path) + // Check for mixed naming conventions (camelCase vs snake_case) + if strings.Contains(base, "-") { + inconsistentNames++ + } + return nil + }) + + if inconsistentNames > 0 { + score -= float64(inconsistentNames) * 5 + issues = append(issues, HealthIssue{ + Dimension: "maintainability", + Description: fmt.Sprintf("%d files use non-standard naming", inconsistentNames), + Severity: "info", + File: dir, + Suggestion: "Use snake_case for Go file names", + }) + } + + // Check for consistent error handling patterns + errPatterns := checkErrorPatterns(dir) + if errPatterns < 0.7 { + score -= 15 + issues = append(issues, HealthIssue{ + Dimension: "maintainability", + Description: "Inconsistent error handling patterns detected", + Severity: "warning", + File: dir, + Suggestion: "Standardize error handling with wrapped errors using fmt.Errorf with %w", + }) + } + + if score < 0 { + score = 0 + } + if score > 100 { + score = 100 + } + return score, issues +} + +// ScoreSecurity evaluates basic security signals in the codebase. +func (hs *HealthScorer) ScoreSecurity(dir string) (float64, []HealthIssue) { + var issues []HealthIssue + score := 100.0 + + dangerousPatterns := []struct { + pattern string + description string + severity string + suggestion string + }{ + {"exec.Command", "Use of exec.Command may allow command injection", "warning", "Validate and sanitize all inputs to exec.Command"}, + {"os.Exec", "Direct OS exec calls detected", "warning", "Ensure executed commands are properly validated"}, + {"net/http", "HTTP usage without explicit TLS configuration", "info", "Consider enforcing HTTPS for external connections"}, + {"crypto/md5", "MD5 is cryptographically broken", "error", "Replace MD5 with SHA-256 or stronger"}, + {"crypto/sha1", "SHA-1 is deprecated for security purposes", "warning", "Replace SHA-1 with SHA-256 or stronger"}, + {"unsafe.Pointer", "Use of unsafe package bypasses type safety", "warning", "Avoid unsafe unless absolutely necessary"}, + } + + foundPatterns := make(map[string][]string) + + _ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if info.IsDir() { + name := info.Name() + if name == ".git" || name == "vendor" || name == "node_modules" || name == "testdata" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + data, readErr := os.ReadFile(path) + if readErr != nil { + return nil + } + content := string(data) + rel, _ := filepath.Rel(dir, path) + if rel == "" { + rel = path + } + + for _, dp := range dangerousPatterns { + if strings.Contains(content, dp.pattern) { + foundPatterns[dp.pattern] = append(foundPatterns[dp.pattern], rel) + } + } + return nil + }) + + for _, dp := range dangerousPatterns { + files := foundPatterns[dp.pattern] + if len(files) > 0 { + var penalty float64 + switch dp.severity { + case "error": + penalty = 15 + case "warning": + penalty = 8 + case "info": + penalty = 3 + } + score -= penalty + issues = append(issues, HealthIssue{ + Dimension: "security", + Description: fmt.Sprintf("%s (found in %d files)", dp.description, len(files)), + Severity: dp.severity, + File: files[0], + Suggestion: dp.suggestion, + }) + } + } + + if score < 0 { + score = 0 + } + return score, issues +} From 32436004441b3d5c8c3b9965e457acbcd763a3a1 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:31:00 +0530 Subject: [PATCH 15/20] refactor(codegraph): split helpers and read/traversal ops into codegraph_cgo_query.go --- internal/codegraph/codegraph_cgo.go | 575 +-------------------- internal/codegraph/codegraph_cgo_query.go | 593 ++++++++++++++++++++++ 2 files changed, 596 insertions(+), 572 deletions(-) create mode 100644 internal/codegraph/codegraph_cgo_query.go diff --git a/internal/codegraph/codegraph_cgo.go b/internal/codegraph/codegraph_cgo.go index da76f046..62f95e79 100644 --- a/internal/codegraph/codegraph_cgo.go +++ b/internal/codegraph/codegraph_cgo.go @@ -4,7 +4,6 @@ package codegraph import ( "context" - "crypto/sha256" "database/sql" "fmt" "os" @@ -14,10 +13,6 @@ import ( "time" sitter "github.com/smacker/go-tree-sitter" - "github.com/smacker/go-tree-sitter/golang" - "github.com/smacker/go-tree-sitter/python" - "github.com/smacker/go-tree-sitter/typescript/tsx" - tstype "github.com/smacker/go-tree-sitter/typescript/typescript" ) // CodeGraph is a tree-sitter based code knowledge graph. @@ -918,570 +913,6 @@ func (cg *CodeGraph) extractNode(node *sitter.Node, source []byte, filePath, ext } } -func extractGoSignature(node *sitter.Node, source []byte) string { - // Extract the full function signature line - start := node.StartByte() - bodyNode := node.ChildByFieldName("body") - if bodyNode != nil { - // Get everything before the body - return strings.TrimSpace(string(source[start:bodyNode.StartByte()])) - } - // Fallback: get the first line - text := string(source[start:node.EndByte()]) - lines := strings.SplitN(text, "\n", 2) - return strings.TrimSpace(lines[0]) -} - -func extractCalleeName(node *sitter.Node, source []byte) string { - // For call_expression, get the function name - funcNode := node.ChildByFieldName("function") - if funcNode == nil { - // Try first child - if node.NamedChildCount() > 0 { - funcNode = node.NamedChild(0) - } - } - if funcNode == nil { - return "" - } - - text := string(source[funcNode.StartByte():funcNode.EndByte()]) - // Remove arguments part - if idx := strings.Index(text, "("); idx > 0 { - text = text[:idx] - } - return strings.TrimSpace(text) -} - -func extractDocstring(node *sitter.Node, source []byte) string { - // Look for comment node before this node - parent := node.Parent() - if parent == nil { - return "" - } - - for i := 0; i < int(parent.NamedChildCount()); i++ { - child := parent.NamedChild(i) - if child.Equal(node) && i > 0 { - prev := parent.NamedChild(i - 1) - if prev.Type() == "comment" || prev.Type() == "block_comment" { - text := string(source[prev.StartByte():prev.EndByte()]) - // Clean comment markers - text = strings.TrimPrefix(text, "//") - text = strings.TrimPrefix(text, "/*") - text = strings.TrimSuffix(text, "*/") - return strings.TrimSpace(text) - } - } - } - return "" -} - -func generateNodeID(filePath, kind, name string, line int) string { - h := sha256.Sum256([]byte(fmt.Sprintf("%s:%s:%s:%d", filePath, kind, name, line))) - return fmt.Sprintf("%x", h[:8]) -} - -func sha256Sum(data []byte) string { - h := sha256.Sum256(data) - return fmt.Sprintf("%x", h[:8]) -} - -func getLanguage(ext string) *sitter.Language { - switch ext { - case ".go": - return golang.GetLanguage() - case ".py": - return python.GetLanguage() - case ".ts": - return tstype.GetLanguage() - case ".tsx", ".js", ".jsx": - return tsx.GetLanguage() - default: - return golang.GetLanguage() // fallback - } -} - -func extractSearchTerms(query string) []string { - // Split on spaces and camelCase boundaries - var terms []string - words := strings.Fields(query) - for _, w := range words { - // Split camelCase - parts := splitCamelCase(w) - terms = append(terms, parts...) - } - return terms -} - -func splitCamelCase(s string) []string { - var parts []string - var current strings.Builder - - for i, r := range s { - if i > 0 && r >= 'A' && r <= 'Z' { - if current.Len() > 0 { - parts = append(parts, current.String()) - current.Reset() - } - } - current.WriteRune(r) - } - if current.Len() > 0 { - parts = append(parts, current.String()) - } - return parts -} - -func contains(slice []string, item string) bool { - for _, s := range slice { - if s == item { - return true - } - } - return false -} - -func boolToInt(b bool) int { - if b { - return 1 - } - return 0 -} - -func truncate(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] + "..." -} - -func scanNodes(rows *sql.Rows) ([]Node, error) { - var nodes []Node - for rows.Next() { - var n Node - err := rows.Scan( - &n.ID, &n.Kind, &n.Name, &n.QualifiedName, &n.FilePath, &n.Language, - &n.StartLine, &n.EndLine, &n.Signature, &n.Docstring, &n.Visibility, &n.IsExported, - ) - if err != nil { - continue - } - nodes = append(nodes, n) - } - return nodes, nil -} - -// SyncResult holds the result of an incremental sync. -type SyncResult struct { - FilesChecked int `json:"files_checked"` - FilesAdded int `json:"files_added"` - FilesModified int `json:"files_modified"` - FilesRemoved int `json:"files_removed"` - NodesUpdated int `json:"nodes_updated"` - DurationMs int `json:"duration_ms"` -} - -// Sync performs an incremental sync — only re-indexes files whose content hash -// has changed since the last index. Removes files that no longer exist. -func (cg *CodeGraph) Sync() (*SyncResult, error) { - start := time.Now() - result := &SyncResult{} - - // Get currently tracked files - trackedFiles := make(map[string]string) // path -> content_hash - rows, err := cg.db.QueryContext(context.Background(), "SELECT path, content_hash FROM files") - if err != nil { - return nil, err - } - for rows.Next() { - var path, hash string - rows.Scan(&path, &hash) - trackedFiles[path] = hash - } - rows.Close() - - // Scan current files - currentFiles := make(map[string]bool) - err = filepath.WalkDir(cg.root, func(path string, d os.DirEntry, err error) error { - if err != nil { - return nil - } - if d.IsDir() { - base := d.Name() - if base == ".git" || base == "node_modules" || base == "vendor" || base == "dist" || base == "build" || base == ".codegraph" || base == "target" || base == "__pycache__" || base == ".venv" { - return filepath.SkipDir - } - return nil - } - ext := filepath.Ext(path) - if cg.extracts[ext] != nil { - relPath, _ := filepath.Rel(cg.root, path) - currentFiles[relPath] = true - result.FilesChecked++ - - // Check if file changed - source, err := os.ReadFile(path) - if err != nil { - return nil - } - hash := sha256Sum(source) - - oldHash, exists := trackedFiles[relPath] - if !exists { - // New file - if err := cg.IndexFile(path); err == nil { - result.FilesAdded++ - result.NodesUpdated++ - } - } else if oldHash != hash { - // Modified file - if err := cg.IndexFile(path); err == nil { - result.FilesModified++ - result.NodesUpdated++ - } - } - // else: unchanged, skip - } - return nil - }) - if err != nil { - return nil, err - } - - // Remove files that no longer exist - cg.mu.Lock() - for trackedPath := range trackedFiles { - if !currentFiles[trackedPath] { - absPath := filepath.Join(cg.root, trackedPath) - relForDelete := trackedPath - cg.db.ExecContext(context.Background(), "DELETE FROM nodes WHERE file_path = ?", absPath) - cg.db.ExecContext(context.Background(), "DELETE FROM edges WHERE source IN (SELECT id FROM nodes WHERE file_path = ?)", absPath) - cg.db.ExecContext(context.Background(), "DELETE FROM files WHERE path = ?", relForDelete) - result.FilesRemoved++ - } - } - cg.mu.Unlock() - - result.DurationMs = int(time.Since(start).Milliseconds()) - return result, nil -} - -// Trace finds the shortest call path between two symbols. -// Returns the chain of nodes from 'from' to 'to', or nil if no path exists. -func (cg *CodeGraph) Trace(fromName, toName string) ([]Node, error) { - cg.mu.RLock() - defer cg.mu.RUnlock() - - // Find source nodes - fromNodes, err := cg.searchByName(fromName, 5) - if err != nil || len(fromNodes) == 0 { - return nil, fmt.Errorf("symbol %q not found", fromName) - } - - // Find target nodes - toNodes, err := cg.searchByName(toName, 5) - if err != nil || len(toNodes) == 0 { - return nil, fmt.Errorf("symbol %q not found", toName) - } - - toIDs := make(map[string]bool) - for _, n := range toNodes { - toIDs[n.ID] = true - } - - // BFS from each source to find shortest path - type step struct { - nodeID string - path []string - } - - for _, from := range fromNodes { - visited := make(map[string]bool) - queue := []step{{nodeID: from.ID, path: []string{from.ID}}} - - for len(queue) > 0 { - current := queue[0] - queue = queue[1:] - - if visited[current.nodeID] { - continue - } - visited[current.nodeID] = true - - if toIDs[current.nodeID] { - // Found path — load full nodes - var path []Node - for _, id := range current.path { - var n Node - err := cg.db.QueryRowContext( - context.Background(), - `SELECT id, kind, name, qualified_name, file_path, language, - start_line, end_line, signature, docstring, visibility, is_exported - FROM nodes WHERE id = ?`, id, - ).Scan( - &n.ID, &n.Kind, &n.Name, &n.QualifiedName, &n.FilePath, &n.Language, - &n.StartLine, &n.EndLine, &n.Signature, &n.Docstring, &n.Visibility, &n.IsExported, - ) - if err == nil { - path = append(path, n) - } - } - return path, nil - } - - // Expand via call edges - edgeRows, _ := cg.db.QueryContext( - context.Background(), - `SELECT target FROM edges WHERE source = ? AND kind IN ('calls', 'references') LIMIT 20`, current.nodeID, - ) - if edgeRows != nil { - for edgeRows.Next() { - var nextID string - edgeRows.Scan(&nextID) - if !visited[nextID] { - newPath := make([]string, len(current.path)+1) - copy(newPath, current.path) - newPath[len(current.path)] = nextID - queue = append(queue, step{nodeID: nextID, path: newPath}) - } - } - edgeRows.Close() - } - } - } - - return nil, fmt.Errorf("no call path from %q to %q", fromName, toName) -} - -// ExploreResult holds source code for multiple symbols grouped by file. -type ExploreResult struct { - Files map[string][]Node `json:"files"` - SourceLines map[string]string `json:"source_lines"` // file:line -> source snippet -} - -// Explore returns source code for several related symbols grouped by file. -func (cg *CodeGraph) Explore(query string, maxFiles int) (*ExploreResult, error) { - if maxFiles <= 0 { - maxFiles = 10 - } - - // Search for symbols - nodes, err := cg.Search(query, maxFiles*3) - if err != nil { - return nil, err - } - if len(nodes) == 0 { - return nil, fmt.Errorf("no symbols found for %q", query) - } - - // Group by file - byFile := make(map[string][]Node) - for _, n := range nodes { - byFile[n.FilePath] = append(byFile[n.FilePath], n) - } - - // Limit files - result := &ExploreResult{ - Files: make(map[string][]Node), - SourceLines: make(map[string]string), - } - - count := 0 - for filePath, fileNodes := range byFile { - if count >= maxFiles { - break - } - - // Read source file - absPath := filePath - if !filepath.IsAbs(absPath) { - absPath = filepath.Join(cg.root, filePath) - } - source, err := os.ReadFile(absPath) - if err != nil { - continue - } - lines := strings.Split(string(source), "\n") - - result.Files[filePath] = fileNodes - - // Extract source snippets for each node - for _, n := range fileNodes { - startIdx := n.StartLine - 1 - endIdx := n.EndLine - if startIdx >= 0 && endIdx <= len(lines) { - snippet := strings.Join(lines[startIdx:endIdx], "\n") - if len(snippet) > 2000 { - snippet = snippet[:2000] + "\n... (truncated)" - } - key := fmt.Sprintf("%s:%d", filePath, n.StartLine) - result.SourceLines[key] = snippet - } - } - count++ - } - - return result, nil -} - -// FileEntry represents a tracked file in the index. -type FileEntry struct { - Path string `json:"path"` - Language string `json:"language"` - Size int `json:"size"` - NodeCount int `json:"node_count"` - IndexedAt int `json:"indexed_at"` -} - -// Files returns the list of all indexed files. -func (cg *CodeGraph) Files(dirFilter string) ([]FileEntry, error) { - cg.mu.RLock() - defer cg.mu.RUnlock() - - query := "SELECT path, language, size, node_count, indexed_at FROM files" - args := []interface{}{} - - if dirFilter != "" { - query += " WHERE path LIKE ?" - args = append(args, dirFilter+"%") - } - query += " ORDER BY path" - - rows, err := cg.db.QueryContext(context.Background(), query, args...) - if err != nil { - return nil, err - } - defer rows.Close() //nolint:errcheck // deferred Close on read-only rows is cleanup-only - - var files []FileEntry - for rows.Next() { - var f FileEntry - rows.Scan(&f.Path, &f.Language, &f.Size, &f.NodeCount, &f.IndexedAt) - files = append(files, f) - } - return files, nil -} - -// StatusResult holds detailed index health information. -type StatusResult struct { - ProjectRoot string `json:"project_root"` - DBPath string `json:"db_path"` - DBSizeBytes int64 `json:"db_size_bytes"` - Files int `json:"files"` - Nodes int `json:"nodes"` - Edges int `json:"edges"` - Unresolved int `json:"unresolved_refs"` - NodesByKind map[string]int `json:"nodes_by_kind"` - FilesByLang map[string]int `json:"files_by_lang"` - JournalMode string `json:"journal_mode"` - UpToDate bool `json:"up_to_date"` -} - -// Status returns detailed index health and statistics. -func (cg *CodeGraph) Status() (*StatusResult, error) { - cg.mu.RLock() - defer cg.mu.RUnlock() - - status := &StatusResult{ - ProjectRoot: cg.root, - DBPath: filepath.Join(cg.root, ".codegraph", "codegraph.db"), - NodesByKind: make(map[string]int), - FilesByLang: make(map[string]int), - } - - // DB size - if info, err := os.Stat(status.DBPath); err == nil { - status.DBSizeBytes = info.Size() - } - - // Counts - cg.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM nodes").Scan(&status.Nodes) - cg.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM edges").Scan(&status.Edges) - cg.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM files").Scan(&status.Files) - cg.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM unresolved_refs").Scan(&status.Unresolved) - - // Nodes by kind - rows, _ := cg.db.QueryContext(context.Background(), "SELECT kind, COUNT(*) FROM nodes GROUP BY kind ORDER BY COUNT(*) DESC") - if rows != nil { - for rows.Next() { - var kind string - var count int - rows.Scan(&kind, &count) - status.NodesByKind[kind] = count - } - rows.Close() - } - - // Files by language - rows, _ = cg.db.QueryContext(context.Background(), "SELECT language, COUNT(*) FROM files GROUP BY language ORDER BY COUNT(*) DESC") - if rows != nil { - for rows.Next() { - var lang string - var count int - rows.Scan(&lang, &count) - status.FilesByLang[lang] = count - } - rows.Close() - } - - // Journal mode - cg.db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&status.JournalMode) - - // Check if up to date (no pending changes) - status.UpToDate = true - fileRows, _ := cg.db.QueryContext(context.Background(), "SELECT path, content_hash FROM files") - if fileRows != nil { - for fileRows.Next() { - var path, hash string - fileRows.Scan(&path, &hash) - absPath := filepath.Join(cg.root, path) - source, err := os.ReadFile(absPath) - if err != nil { - status.UpToDate = false - break - } - if sha256Sum(source) != hash { - status.UpToDate = false - break - } - } - fileRows.Close() - } - - return status, nil -} - -// searchByName is an internal search that returns nodes matching a name. -func (cg *CodeGraph) searchByName(name string, limit int) ([]Node, error) { - rows, err := cg.db.QueryContext( - context.Background(), - `SELECT id, kind, name, qualified_name, file_path, language, - start_line, end_line, signature, docstring, visibility, is_exported - FROM nodes WHERE name = ? OR name LIKE ? LIMIT ?`, - name, "%"+name+"%", limit, - ) - if err != nil { - return nil, err - } - defer rows.Close() //nolint:errcheck // deferred Close on read-only rows is cleanup-only - return scanNodes(rows) -} - -// GetNode returns a single node by ID. -func (cg *CodeGraph) GetNode(id string) (Node, error) { - cg.mu.RLock() - defer cg.mu.RUnlock() - - var n Node - err := cg.db.QueryRowContext( - context.Background(), - `SELECT id, kind, name, qualified_name, file_path, language, - start_line, end_line, signature, docstring, visibility, is_exported - FROM nodes WHERE id = ?`, id, - ).Scan( - &n.ID, &n.Kind, &n.Name, &n.QualifiedName, &n.FilePath, &n.Language, - &n.StartLine, &n.EndLine, &n.Signature, &n.Docstring, &n.Visibility, &n.IsExported, - ) - return n, err -} +// Standalone extraction/formatting helpers and the high-level read/traversal +// operations (Sync, Trace, Explore, Files, Status, searchByName, GetNode) live +// in codegraph_cgo_query.go. diff --git a/internal/codegraph/codegraph_cgo_query.go b/internal/codegraph/codegraph_cgo_query.go new file mode 100644 index 00000000..cf0b3d95 --- /dev/null +++ b/internal/codegraph/codegraph_cgo_query.go @@ -0,0 +1,593 @@ +//go:build cgo + +package codegraph + +import ( + "context" + "crypto/sha256" + "database/sql" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/golang" + "github.com/smacker/go-tree-sitter/python" + "github.com/smacker/go-tree-sitter/typescript/tsx" + tstype "github.com/smacker/go-tree-sitter/typescript/typescript" +) + +// This file holds the standalone extraction/formatting helpers and the +// high-level read/traversal operations (Sync, Trace, Explore, Files, Status, +// searchByName, GetNode) for the CodeGraph. The CodeGraph type, schema setup, +// indexing, and the core query/traversal methods live in codegraph_cgo.go. + +func extractGoSignature(node *sitter.Node, source []byte) string { + // Extract the full function signature line + start := node.StartByte() + bodyNode := node.ChildByFieldName("body") + if bodyNode != nil { + // Get everything before the body + return strings.TrimSpace(string(source[start:bodyNode.StartByte()])) + } + // Fallback: get the first line + text := string(source[start:node.EndByte()]) + lines := strings.SplitN(text, "\n", 2) + return strings.TrimSpace(lines[0]) +} + +func extractCalleeName(node *sitter.Node, source []byte) string { + // For call_expression, get the function name + funcNode := node.ChildByFieldName("function") + if funcNode == nil { + // Try first child + if node.NamedChildCount() > 0 { + funcNode = node.NamedChild(0) + } + } + if funcNode == nil { + return "" + } + + text := string(source[funcNode.StartByte():funcNode.EndByte()]) + // Remove arguments part + if idx := strings.Index(text, "("); idx > 0 { + text = text[:idx] + } + return strings.TrimSpace(text) +} + +func extractDocstring(node *sitter.Node, source []byte) string { + // Look for comment node before this node + parent := node.Parent() + if parent == nil { + return "" + } + + for i := 0; i < int(parent.NamedChildCount()); i++ { + child := parent.NamedChild(i) + if child.Equal(node) && i > 0 { + prev := parent.NamedChild(i - 1) + if prev.Type() == "comment" || prev.Type() == "block_comment" { + text := string(source[prev.StartByte():prev.EndByte()]) + // Clean comment markers + text = strings.TrimPrefix(text, "//") + text = strings.TrimPrefix(text, "/*") + text = strings.TrimSuffix(text, "*/") + return strings.TrimSpace(text) + } + } + } + return "" +} + +func generateNodeID(filePath, kind, name string, line int) string { + h := sha256.Sum256([]byte(fmt.Sprintf("%s:%s:%s:%d", filePath, kind, name, line))) + return fmt.Sprintf("%x", h[:8]) +} + +func sha256Sum(data []byte) string { + h := sha256.Sum256(data) + return fmt.Sprintf("%x", h[:8]) +} + +func getLanguage(ext string) *sitter.Language { + switch ext { + case ".go": + return golang.GetLanguage() + case ".py": + return python.GetLanguage() + case ".ts": + return tstype.GetLanguage() + case ".tsx", ".js", ".jsx": + return tsx.GetLanguage() + default: + return golang.GetLanguage() // fallback + } +} + +func extractSearchTerms(query string) []string { + // Split on spaces and camelCase boundaries + var terms []string + words := strings.Fields(query) + for _, w := range words { + // Split camelCase + parts := splitCamelCase(w) + terms = append(terms, parts...) + } + return terms +} + +func splitCamelCase(s string) []string { + var parts []string + var current strings.Builder + + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + if current.Len() > 0 { + parts = append(parts, current.String()) + current.Reset() + } + } + current.WriteRune(r) + } + if current.Len() > 0 { + parts = append(parts, current.String()) + } + return parts +} + +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +func boolToInt(b bool) int { + if b { + return 1 + } + return 0 +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +func scanNodes(rows *sql.Rows) ([]Node, error) { + var nodes []Node + for rows.Next() { + var n Node + err := rows.Scan( + &n.ID, &n.Kind, &n.Name, &n.QualifiedName, &n.FilePath, &n.Language, + &n.StartLine, &n.EndLine, &n.Signature, &n.Docstring, &n.Visibility, &n.IsExported, + ) + if err != nil { + continue + } + nodes = append(nodes, n) + } + return nodes, nil +} + +// SyncResult holds the result of an incremental sync. +type SyncResult struct { + FilesChecked int `json:"files_checked"` + FilesAdded int `json:"files_added"` + FilesModified int `json:"files_modified"` + FilesRemoved int `json:"files_removed"` + NodesUpdated int `json:"nodes_updated"` + DurationMs int `json:"duration_ms"` +} + +// Sync performs an incremental sync — only re-indexes files whose content hash +// has changed since the last index. Removes files that no longer exist. +func (cg *CodeGraph) Sync() (*SyncResult, error) { + start := time.Now() + result := &SyncResult{} + + // Get currently tracked files + trackedFiles := make(map[string]string) // path -> content_hash + rows, err := cg.db.QueryContext(context.Background(), "SELECT path, content_hash FROM files") + if err != nil { + return nil, err + } + for rows.Next() { + var path, hash string + rows.Scan(&path, &hash) + trackedFiles[path] = hash + } + rows.Close() + + // Scan current files + currentFiles := make(map[string]bool) + err = filepath.WalkDir(cg.root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + base := d.Name() + if base == ".git" || base == "node_modules" || base == "vendor" || base == "dist" || base == "build" || base == ".codegraph" || base == "target" || base == "__pycache__" || base == ".venv" { + return filepath.SkipDir + } + return nil + } + ext := filepath.Ext(path) + if cg.extracts[ext] != nil { + relPath, _ := filepath.Rel(cg.root, path) + currentFiles[relPath] = true + result.FilesChecked++ + + // Check if file changed + source, err := os.ReadFile(path) + if err != nil { + return nil + } + hash := sha256Sum(source) + + oldHash, exists := trackedFiles[relPath] + if !exists { + // New file + if err := cg.IndexFile(path); err == nil { + result.FilesAdded++ + result.NodesUpdated++ + } + } else if oldHash != hash { + // Modified file + if err := cg.IndexFile(path); err == nil { + result.FilesModified++ + result.NodesUpdated++ + } + } + // else: unchanged, skip + } + return nil + }) + if err != nil { + return nil, err + } + + // Remove files that no longer exist + cg.mu.Lock() + for trackedPath := range trackedFiles { + if !currentFiles[trackedPath] { + absPath := filepath.Join(cg.root, trackedPath) + relForDelete := trackedPath + cg.db.ExecContext(context.Background(), "DELETE FROM nodes WHERE file_path = ?", absPath) + cg.db.ExecContext(context.Background(), "DELETE FROM edges WHERE source IN (SELECT id FROM nodes WHERE file_path = ?)", absPath) + cg.db.ExecContext(context.Background(), "DELETE FROM files WHERE path = ?", relForDelete) + result.FilesRemoved++ + } + } + cg.mu.Unlock() + + result.DurationMs = int(time.Since(start).Milliseconds()) + return result, nil +} + +// Trace finds the shortest call path between two symbols. +// Returns the chain of nodes from 'from' to 'to', or nil if no path exists. +func (cg *CodeGraph) Trace(fromName, toName string) ([]Node, error) { + cg.mu.RLock() + defer cg.mu.RUnlock() + + // Find source nodes + fromNodes, err := cg.searchByName(fromName, 5) + if err != nil || len(fromNodes) == 0 { + return nil, fmt.Errorf("symbol %q not found", fromName) + } + + // Find target nodes + toNodes, err := cg.searchByName(toName, 5) + if err != nil || len(toNodes) == 0 { + return nil, fmt.Errorf("symbol %q not found", toName) + } + + toIDs := make(map[string]bool) + for _, n := range toNodes { + toIDs[n.ID] = true + } + + // BFS from each source to find shortest path + type step struct { + nodeID string + path []string + } + + for _, from := range fromNodes { + visited := make(map[string]bool) + queue := []step{{nodeID: from.ID, path: []string{from.ID}}} + + for len(queue) > 0 { + current := queue[0] + queue = queue[1:] + + if visited[current.nodeID] { + continue + } + visited[current.nodeID] = true + + if toIDs[current.nodeID] { + // Found path — load full nodes + var path []Node + for _, id := range current.path { + var n Node + err := cg.db.QueryRowContext( + context.Background(), + `SELECT id, kind, name, qualified_name, file_path, language, + start_line, end_line, signature, docstring, visibility, is_exported + FROM nodes WHERE id = ?`, id, + ).Scan( + &n.ID, &n.Kind, &n.Name, &n.QualifiedName, &n.FilePath, &n.Language, + &n.StartLine, &n.EndLine, &n.Signature, &n.Docstring, &n.Visibility, &n.IsExported, + ) + if err == nil { + path = append(path, n) + } + } + return path, nil + } + + // Expand via call edges + edgeRows, _ := cg.db.QueryContext( + context.Background(), + `SELECT target FROM edges WHERE source = ? AND kind IN ('calls', 'references') LIMIT 20`, current.nodeID, + ) + if edgeRows != nil { + for edgeRows.Next() { + var nextID string + edgeRows.Scan(&nextID) + if !visited[nextID] { + newPath := make([]string, len(current.path)+1) + copy(newPath, current.path) + newPath[len(current.path)] = nextID + queue = append(queue, step{nodeID: nextID, path: newPath}) + } + } + edgeRows.Close() + } + } + } + + return nil, fmt.Errorf("no call path from %q to %q", fromName, toName) +} + +// ExploreResult holds source code for multiple symbols grouped by file. +type ExploreResult struct { + Files map[string][]Node `json:"files"` + SourceLines map[string]string `json:"source_lines"` // file:line -> source snippet +} + +// Explore returns source code for several related symbols grouped by file. +func (cg *CodeGraph) Explore(query string, maxFiles int) (*ExploreResult, error) { + if maxFiles <= 0 { + maxFiles = 10 + } + + // Search for symbols + nodes, err := cg.Search(query, maxFiles*3) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, fmt.Errorf("no symbols found for %q", query) + } + + // Group by file + byFile := make(map[string][]Node) + for _, n := range nodes { + byFile[n.FilePath] = append(byFile[n.FilePath], n) + } + + // Limit files + result := &ExploreResult{ + Files: make(map[string][]Node), + SourceLines: make(map[string]string), + } + + count := 0 + for filePath, fileNodes := range byFile { + if count >= maxFiles { + break + } + + // Read source file + absPath := filePath + if !filepath.IsAbs(absPath) { + absPath = filepath.Join(cg.root, filePath) + } + source, err := os.ReadFile(absPath) + if err != nil { + continue + } + lines := strings.Split(string(source), "\n") + + result.Files[filePath] = fileNodes + + // Extract source snippets for each node + for _, n := range fileNodes { + startIdx := n.StartLine - 1 + endIdx := n.EndLine + if startIdx >= 0 && endIdx <= len(lines) { + snippet := strings.Join(lines[startIdx:endIdx], "\n") + if len(snippet) > 2000 { + snippet = snippet[:2000] + "\n... (truncated)" + } + key := fmt.Sprintf("%s:%d", filePath, n.StartLine) + result.SourceLines[key] = snippet + } + } + count++ + } + + return result, nil +} + +// FileEntry represents a tracked file in the index. +type FileEntry struct { + Path string `json:"path"` + Language string `json:"language"` + Size int `json:"size"` + NodeCount int `json:"node_count"` + IndexedAt int `json:"indexed_at"` +} + +// Files returns the list of all indexed files. +func (cg *CodeGraph) Files(dirFilter string) ([]FileEntry, error) { + cg.mu.RLock() + defer cg.mu.RUnlock() + + query := "SELECT path, language, size, node_count, indexed_at FROM files" + args := []interface{}{} + + if dirFilter != "" { + query += " WHERE path LIKE ?" + args = append(args, dirFilter+"%") + } + query += " ORDER BY path" + + rows, err := cg.db.QueryContext(context.Background(), query, args...) + if err != nil { + return nil, err + } + defer rows.Close() //nolint:errcheck // deferred Close on read-only rows is cleanup-only + + var files []FileEntry + for rows.Next() { + var f FileEntry + rows.Scan(&f.Path, &f.Language, &f.Size, &f.NodeCount, &f.IndexedAt) + files = append(files, f) + } + return files, nil +} + +// StatusResult holds detailed index health information. +type StatusResult struct { + ProjectRoot string `json:"project_root"` + DBPath string `json:"db_path"` + DBSizeBytes int64 `json:"db_size_bytes"` + Files int `json:"files"` + Nodes int `json:"nodes"` + Edges int `json:"edges"` + Unresolved int `json:"unresolved_refs"` + NodesByKind map[string]int `json:"nodes_by_kind"` + FilesByLang map[string]int `json:"files_by_lang"` + JournalMode string `json:"journal_mode"` + UpToDate bool `json:"up_to_date"` +} + +// Status returns detailed index health and statistics. +func (cg *CodeGraph) Status() (*StatusResult, error) { + cg.mu.RLock() + defer cg.mu.RUnlock() + + status := &StatusResult{ + ProjectRoot: cg.root, + DBPath: filepath.Join(cg.root, ".codegraph", "codegraph.db"), + NodesByKind: make(map[string]int), + FilesByLang: make(map[string]int), + } + + // DB size + if info, err := os.Stat(status.DBPath); err == nil { + status.DBSizeBytes = info.Size() + } + + // Counts + cg.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM nodes").Scan(&status.Nodes) + cg.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM edges").Scan(&status.Edges) + cg.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM files").Scan(&status.Files) + cg.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM unresolved_refs").Scan(&status.Unresolved) + + // Nodes by kind + rows, _ := cg.db.QueryContext(context.Background(), "SELECT kind, COUNT(*) FROM nodes GROUP BY kind ORDER BY COUNT(*) DESC") + if rows != nil { + for rows.Next() { + var kind string + var count int + rows.Scan(&kind, &count) + status.NodesByKind[kind] = count + } + rows.Close() + } + + // Files by language + rows, _ = cg.db.QueryContext(context.Background(), "SELECT language, COUNT(*) FROM files GROUP BY language ORDER BY COUNT(*) DESC") + if rows != nil { + for rows.Next() { + var lang string + var count int + rows.Scan(&lang, &count) + status.FilesByLang[lang] = count + } + rows.Close() + } + + // Journal mode + cg.db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&status.JournalMode) + + // Check if up to date (no pending changes) + status.UpToDate = true + fileRows, _ := cg.db.QueryContext(context.Background(), "SELECT path, content_hash FROM files") + if fileRows != nil { + for fileRows.Next() { + var path, hash string + fileRows.Scan(&path, &hash) + absPath := filepath.Join(cg.root, path) + source, err := os.ReadFile(absPath) + if err != nil { + status.UpToDate = false + break + } + if sha256Sum(source) != hash { + status.UpToDate = false + break + } + } + fileRows.Close() + } + + return status, nil +} + +// searchByName is an internal search that returns nodes matching a name. +func (cg *CodeGraph) searchByName(name string, limit int) ([]Node, error) { + rows, err := cg.db.QueryContext( + context.Background(), + `SELECT id, kind, name, qualified_name, file_path, language, + start_line, end_line, signature, docstring, visibility, is_exported + FROM nodes WHERE name = ? OR name LIKE ? LIMIT ?`, + name, "%"+name+"%", limit, + ) + if err != nil { + return nil, err + } + defer rows.Close() //nolint:errcheck // deferred Close on read-only rows is cleanup-only + return scanNodes(rows) +} + +// GetNode returns a single node by ID. +func (cg *CodeGraph) GetNode(id string) (Node, error) { + cg.mu.RLock() + defer cg.mu.RUnlock() + + var n Node + err := cg.db.QueryRowContext( + context.Background(), + `SELECT id, kind, name, qualified_name, file_path, language, + start_line, end_line, signature, docstring, visibility, is_exported + FROM nodes WHERE id = ?`, id, + ).Scan( + &n.ID, &n.Kind, &n.Name, &n.QualifiedName, &n.FilePath, &n.Language, + &n.StartLine, &n.EndLine, &n.Signature, &n.Docstring, &n.Visibility, &n.IsExported, + ) + return n, err +} From 3e29ef2c696f540a409c6f66eebd923a3caef4ce Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 19:33:20 +0530 Subject: [PATCH 16/20] refactor(codegraph): split ranking/impact/coupling/cross-repo algorithms into algorithms_cgo_more.go --- internal/codegraph/algorithms_cgo.go | 391 +-------------------- internal/codegraph/algorithms_cgo_more.go | 402 ++++++++++++++++++++++ 2 files changed, 405 insertions(+), 388 deletions(-) create mode 100644 internal/codegraph/algorithms_cgo_more.go diff --git a/internal/codegraph/algorithms_cgo.go b/internal/codegraph/algorithms_cgo.go index acf79e07..14e95492 100644 --- a/internal/codegraph/algorithms_cgo.go +++ b/internal/codegraph/algorithms_cgo.go @@ -660,391 +660,6 @@ func computeModularity(adj map[string]map[string]float64, community map[string]i return q / (2 * totalWeight) } -// PageRank computes PageRank on the code graph's call/reference edges. -// This is more accurate than repomap's PageRank because it uses the precise -// call graph from tree-sitter parsing rather than string matching. -func (cg *CodeGraph) PageRank(iterations int, damping float64) (map[string]float64, error) { - if iterations <= 0 { - iterations = 20 - } - if damping <= 0 { - damping = 0.85 - } - - cg.mu.RLock() - defer cg.mu.RUnlock() - - // Build adjacency - outlinks := make(map[string][]string) - inlinks := make(map[string][]string) - nodes := make(map[string]bool) - - rows, err := cg.db.QueryContext(context.Background(), "SELECT source, target FROM edges WHERE kind IN ('calls', 'references', 'imports')") - if err != nil { - return nil, err - } - defer rows.Close() //nolint:errcheck // deferred Close on read-only rows is cleanup-only - - for rows.Next() { - var src, tgt string - if err := rows.Scan(&src, &tgt); err != nil { - return nil, fmt.Errorf("scanning edge row: %w", err) - } - outlinks[src] = append(outlinks[src], tgt) - inlinks[tgt] = append(inlinks[tgt], src) - nodes[src] = true - nodes[tgt] = true - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("iterating edges: %w", err) - } - - n := float64(len(nodes)) - if n == 0 { - return make(map[string]float64), nil - } - - // Initialize ranks - rank := make(map[string]float64) - for id := range nodes { - rank[id] = 1.0 / n - } - - // Iterate - for iter := 0; iter < iterations; iter++ { - newRank := make(map[string]float64) - - // Collect dangling rank (nodes with no outlinks) - danglingSum := 0.0 - for id := range nodes { - if len(outlinks[id]) == 0 { - danglingSum += rank[id] - } - } - - for id := range nodes { - sum := 0.0 - for _, src := range inlinks[id] { - sum += rank[src] / float64(len(outlinks[src])) - } - newRank[id] = (1-damping)/n + damping*(sum+danglingSum/n) - } - - rank = newRank - } - - return rank, nil -} - -// ImpactAnalysis computes the blast radius of changing a symbol. -// Uses the full call graph to find all directly and transitively affected nodes. -func (cg *CodeGraph) ImpactAnalysis(nodeID string, maxDepth int) (*ImpactResult, error) { - if maxDepth <= 0 { - maxDepth = 3 - } - - cg.mu.RLock() - defer cg.mu.RUnlock() - - result := &ImpactResult{ - Root: nodeID, - Impacted: make(map[string]int), // nodeID -> depth - } - - // BFS from the changed node through incoming edges - visited := make(map[string]bool) - type step struct { - id string - depth int - } - queue := []step{{nodeID, 0}} - - for len(queue) > 0 { - s := queue[0] - queue = queue[1:] - - if visited[s.id] || s.depth > maxDepth { - continue - } - visited[s.id] = true - result.Impacted[s.id] = s.depth - - // Get all nodes that depend on this one - rows, _ := cg.db.QueryContext( - context.Background(), - `SELECT source FROM edges WHERE target = ? AND kind IN ('calls', 'references', 'imports', 'extends', 'implements')`, s.id, - ) - if rows != nil { - for rows.Next() { - var source string - if err := rows.Scan(&source); err != nil { - rows.Close() - return nil, fmt.Errorf("scanning dependency row for %s: %w", s.id, err) - } - if !visited[source] { - queue = append(queue, step{source, s.depth + 1}) - } - } - rows.Close() - } - } - - // Load node details - for id, depth := range result.Impacted { - var n Node - err := cg.db.QueryRowContext( - context.Background(), - `SELECT id, kind, name, qualified_name, file_path, language, - start_line, end_line, signature, docstring, visibility, is_exported - FROM nodes WHERE id = ?`, id, - ).Scan( - &n.ID, &n.Kind, &n.Name, &n.QualifiedName, &n.FilePath, &n.Language, - &n.StartLine, &n.EndLine, &n.Signature, &n.Docstring, &n.Visibility, &n.IsExported, - ) - if err == nil { - result.Nodes = append(result.Nodes, n) - if depth > result.MaxDepth { - result.MaxDepth = depth - } - } - } - - // Sort by depth - sort.Slice(result.Nodes, func(i, j int) bool { - return result.Impacted[result.Nodes[i].ID] < result.Impacted[result.Nodes[j].ID] - }) - - return result, nil -} - -// ImpactResult holds the result of impact analysis. -type ImpactResult struct { - Root string `json:"root"` - Impacted map[string]int `json:"impacted"` // nodeID -> depth - Nodes []Node `json:"nodes"` - MaxDepth int `json:"max_depth"` -} - -// CouplingMetric represents coupling between two modules/files. -type CouplingMetric struct { - FileA string `json:"file_a"` - FileB string `json:"file_b"` - SharedDeps int `json:"shared_deps"` // number of shared dependencies - Coupling float64 `json:"coupling"` // 0-1 coupling score -} - -// AnalyzeCoupling finds pairs of files that are tightly coupled (share many dependencies). -func (cg *CodeGraph) AnalyzeCoupling(topN int) ([]CouplingMetric, error) { - if topN <= 0 { - topN = 10 - } - - cg.mu.RLock() - defer cg.mu.RUnlock() - - // Build file -> set of referenced symbols - fileDeps := make(map[string]map[string]bool) - rows, err := cg.db.QueryContext(context.Background(), "SELECT file_path, target FROM edges e JOIN nodes n ON n.id = e.source WHERE e.kind IN ('calls', 'references', 'imports')") - if err != nil { - return nil, err - } - defer rows.Close() //nolint:errcheck // deferred Close on read-only rows is cleanup-only - - for rows.Next() { - var filePath, target string - if err := rows.Scan(&filePath, &target); err != nil { - return nil, fmt.Errorf("scanning file dependency row: %w", err) - } - if fileDeps[filePath] == nil { - fileDeps[filePath] = make(map[string]bool) - } - fileDeps[filePath][target] = true - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("iterating file dependencies: %w", err) - } - - // Compute pairwise coupling - var metrics []CouplingMetric - files := make([]string, 0, len(fileDeps)) - for f := range fileDeps { - files = append(files, f) - } - - for i := 0; i < len(files); i++ { - for j := i + 1; j < len(files); j++ { - shared := 0 - for dep := range fileDeps[files[i]] { - if fileDeps[files[j]][dep] { - shared++ - } - } - if shared > 0 { - total := len(fileDeps[files[i]]) + len(fileDeps[files[j]]) - coupling := float64(shared*2) / float64(total) - metrics = append(metrics, CouplingMetric{ - FileA: files[i], - FileB: files[j], - SharedDeps: shared, - Coupling: coupling, - }) - } - } - } - - sort.Slice(metrics, func(i, j int) bool { - return metrics[i].Coupling > metrics[j].Coupling - }) - - if topN > len(metrics) { - topN = len(metrics) - } - return metrics[:topN], nil -} - -// CrossRepoQuery queries across multiple codegraph databases. -// Useful for finding relationships between hawk, eyrie, tok, yaad, etc. -func CrossRepoQuery(repos []string, query string, limit int) (map[string][]Node, error) { - results := make(map[string][]Node) - - for _, repoRoot := range repos { - cg, err := Open(repoRoot) - if err != nil { - continue // Skip repos without codegraph - } - - nodes, err := cg.Search(query, limit) - cg.Close() - if err != nil || len(nodes) == 0 { - continue - } - - results[repoRoot] = nodes - } - - return results, nil -} - -// CrossRepoImpact finds the impact of changing a symbol across multiple repos. -// If a symbol in hawk calls a symbol in eyrie, this traces that cross-repo dependency. -func CrossRepoImpact(repos []string, symbol string, maxDepth int) (map[string]*ImpactResult, error) { - results := make(map[string]*ImpactResult) - - for _, repoRoot := range repos { - cg, err := Open(repoRoot) - if err != nil { - continue - } - - // Search for the symbol - nodes, err := cg.Search(symbol, 5) - if err != nil || len(nodes) == 0 { - cg.Close() - continue - } - - // Get impact for each matching symbol - for _, n := range nodes { - impact, err := cg.ImpactAnalysis(n.ID, maxDepth) - if err != nil { - continue - } - results[repoRoot+":"+n.Name] = impact - } - - cg.Close() - } - - return results, nil -} - -// FindCrossRepoCalls finds function calls that cross repo boundaries. -// For example, hawk calling eyrie functions. -func FindCrossRepoCalls(repos []string) ([]CrossRepoCall, error) { - type repoSymbol struct { - repo string - node Node - } - - // Build a map of all symbols across repos - allSymbols := make(map[string][]repoSymbol) // name -> [{repo, node}] - repoNodes := make(map[string]map[string]bool) - - for _, repoRoot := range repos { - cg, err := Open(repoRoot) - if err != nil { - continue - } - - repoNodes[repoRoot] = make(map[string]bool) - - // Get all nodes - rows, err := cg.db.QueryContext(context.Background(), "SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, signature, docstring, visibility, is_exported FROM nodes") - if err != nil { - cg.Close() - continue - } - - nodes, _ := scanNodes(rows) - rows.Close() - - for _, n := range nodes { - allSymbols[n.Name] = append(allSymbols[n.Name], repoSymbol{repoRoot, n}) - repoNodes[repoRoot][n.ID] = true - } - - cg.Close() - } - - // Find calls that reference symbols in other repos - var crossCalls []CrossRepoCall - - for _, repoRoot := range repos { - cg, err := Open(repoRoot) - if err != nil { - continue - } - - // Get unresolved refs (calls to symbols not in this repo) - rows, err := cg.db.QueryContext(context.Background(), "SELECT from_node_id, reference_name, file_path, line FROM unresolved_refs") - if err != nil { - cg.Close() - continue - } - - for rows.Next() { - var fromID, refName, filePath string - var line int - rows.Scan(&fromID, &refName, &filePath, &line) - - // Check if this reference exists in another repo - for _, target := range allSymbols[refName] { - if target.repo != repoRoot { - crossCalls = append(crossCalls, CrossRepoCall{ - FromRepo: repoRoot, - ToRepo: target.repo, - Symbol: refName, - File: filePath, - Line: line, - Target: target.node, - }) - } - } - } - - rows.Close() - cg.Close() - } - - return crossCalls, nil -} - -// CrossRepoCall represents a function call that crosses repo boundaries. -type CrossRepoCall struct { - FromRepo string `json:"from_repo"` - ToRepo string `json:"to_repo"` - Symbol string `json:"symbol"` - File string `json:"file"` - Line int `json:"line"` - Target Node `json:"target"` -} +// Ranking, impact, coupling, and cross-repo graph algorithms (PageRank, +// ImpactAnalysis, AnalyzeCoupling, CrossRepoQuery, CrossRepoImpact, +// FindCrossRepoCalls) live in algorithms_cgo_more.go. diff --git a/internal/codegraph/algorithms_cgo_more.go b/internal/codegraph/algorithms_cgo_more.go new file mode 100644 index 00000000..dfafe46e --- /dev/null +++ b/internal/codegraph/algorithms_cgo_more.go @@ -0,0 +1,402 @@ +//go:build cgo + +package codegraph + +import ( + "context" + "fmt" + "sort" +) + +// This file holds the ranking, impact, coupling, and cross-repo graph +// algorithms. Centrality, community detection, connected components, graph +// diff/snapshot, and dead-code analysis live in algorithms_cgo.go. + +// PageRank computes PageRank on the code graph's call/reference edges. +// This is more accurate than repomap's PageRank because it uses the precise +// call graph from tree-sitter parsing rather than string matching. +func (cg *CodeGraph) PageRank(iterations int, damping float64) (map[string]float64, error) { + if iterations <= 0 { + iterations = 20 + } + if damping <= 0 { + damping = 0.85 + } + + cg.mu.RLock() + defer cg.mu.RUnlock() + + // Build adjacency + outlinks := make(map[string][]string) + inlinks := make(map[string][]string) + nodes := make(map[string]bool) + + rows, err := cg.db.QueryContext(context.Background(), "SELECT source, target FROM edges WHERE kind IN ('calls', 'references', 'imports')") + if err != nil { + return nil, err + } + defer rows.Close() //nolint:errcheck // deferred Close on read-only rows is cleanup-only + + for rows.Next() { + var src, tgt string + if err := rows.Scan(&src, &tgt); err != nil { + return nil, fmt.Errorf("scanning edge row: %w", err) + } + outlinks[src] = append(outlinks[src], tgt) + inlinks[tgt] = append(inlinks[tgt], src) + nodes[src] = true + nodes[tgt] = true + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterating edges: %w", err) + } + + n := float64(len(nodes)) + if n == 0 { + return make(map[string]float64), nil + } + + // Initialize ranks + rank := make(map[string]float64) + for id := range nodes { + rank[id] = 1.0 / n + } + + // Iterate + for iter := 0; iter < iterations; iter++ { + newRank := make(map[string]float64) + + // Collect dangling rank (nodes with no outlinks) + danglingSum := 0.0 + for id := range nodes { + if len(outlinks[id]) == 0 { + danglingSum += rank[id] + } + } + + for id := range nodes { + sum := 0.0 + for _, src := range inlinks[id] { + sum += rank[src] / float64(len(outlinks[src])) + } + newRank[id] = (1-damping)/n + damping*(sum+danglingSum/n) + } + + rank = newRank + } + + return rank, nil +} + +// ImpactAnalysis computes the blast radius of changing a symbol. +// Uses the full call graph to find all directly and transitively affected nodes. +func (cg *CodeGraph) ImpactAnalysis(nodeID string, maxDepth int) (*ImpactResult, error) { + if maxDepth <= 0 { + maxDepth = 3 + } + + cg.mu.RLock() + defer cg.mu.RUnlock() + + result := &ImpactResult{ + Root: nodeID, + Impacted: make(map[string]int), // nodeID -> depth + } + + // BFS from the changed node through incoming edges + visited := make(map[string]bool) + type step struct { + id string + depth int + } + queue := []step{{nodeID, 0}} + + for len(queue) > 0 { + s := queue[0] + queue = queue[1:] + + if visited[s.id] || s.depth > maxDepth { + continue + } + visited[s.id] = true + result.Impacted[s.id] = s.depth + + // Get all nodes that depend on this one + rows, _ := cg.db.QueryContext( + context.Background(), + `SELECT source FROM edges WHERE target = ? AND kind IN ('calls', 'references', 'imports', 'extends', 'implements')`, s.id, + ) + if rows != nil { + for rows.Next() { + var source string + if err := rows.Scan(&source); err != nil { + rows.Close() + return nil, fmt.Errorf("scanning dependency row for %s: %w", s.id, err) + } + if !visited[source] { + queue = append(queue, step{source, s.depth + 1}) + } + } + rows.Close() + } + } + + // Load node details + for id, depth := range result.Impacted { + var n Node + err := cg.db.QueryRowContext( + context.Background(), + `SELECT id, kind, name, qualified_name, file_path, language, + start_line, end_line, signature, docstring, visibility, is_exported + FROM nodes WHERE id = ?`, id, + ).Scan( + &n.ID, &n.Kind, &n.Name, &n.QualifiedName, &n.FilePath, &n.Language, + &n.StartLine, &n.EndLine, &n.Signature, &n.Docstring, &n.Visibility, &n.IsExported, + ) + if err == nil { + result.Nodes = append(result.Nodes, n) + if depth > result.MaxDepth { + result.MaxDepth = depth + } + } + } + + // Sort by depth + sort.Slice(result.Nodes, func(i, j int) bool { + return result.Impacted[result.Nodes[i].ID] < result.Impacted[result.Nodes[j].ID] + }) + + return result, nil +} + +// ImpactResult holds the result of impact analysis. +type ImpactResult struct { + Root string `json:"root"` + Impacted map[string]int `json:"impacted"` // nodeID -> depth + Nodes []Node `json:"nodes"` + MaxDepth int `json:"max_depth"` +} + +// CouplingMetric represents coupling between two modules/files. +type CouplingMetric struct { + FileA string `json:"file_a"` + FileB string `json:"file_b"` + SharedDeps int `json:"shared_deps"` // number of shared dependencies + Coupling float64 `json:"coupling"` // 0-1 coupling score +} + +// AnalyzeCoupling finds pairs of files that are tightly coupled (share many dependencies). +func (cg *CodeGraph) AnalyzeCoupling(topN int) ([]CouplingMetric, error) { + if topN <= 0 { + topN = 10 + } + + cg.mu.RLock() + defer cg.mu.RUnlock() + + // Build file -> set of referenced symbols + fileDeps := make(map[string]map[string]bool) + rows, err := cg.db.QueryContext(context.Background(), "SELECT file_path, target FROM edges e JOIN nodes n ON n.id = e.source WHERE e.kind IN ('calls', 'references', 'imports')") + if err != nil { + return nil, err + } + defer rows.Close() //nolint:errcheck // deferred Close on read-only rows is cleanup-only + + for rows.Next() { + var filePath, target string + if err := rows.Scan(&filePath, &target); err != nil { + return nil, fmt.Errorf("scanning file dependency row: %w", err) + } + if fileDeps[filePath] == nil { + fileDeps[filePath] = make(map[string]bool) + } + fileDeps[filePath][target] = true + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterating file dependencies: %w", err) + } + + // Compute pairwise coupling + var metrics []CouplingMetric + files := make([]string, 0, len(fileDeps)) + for f := range fileDeps { + files = append(files, f) + } + + for i := 0; i < len(files); i++ { + for j := i + 1; j < len(files); j++ { + shared := 0 + for dep := range fileDeps[files[i]] { + if fileDeps[files[j]][dep] { + shared++ + } + } + if shared > 0 { + total := len(fileDeps[files[i]]) + len(fileDeps[files[j]]) + coupling := float64(shared*2) / float64(total) + metrics = append(metrics, CouplingMetric{ + FileA: files[i], + FileB: files[j], + SharedDeps: shared, + Coupling: coupling, + }) + } + } + } + + sort.Slice(metrics, func(i, j int) bool { + return metrics[i].Coupling > metrics[j].Coupling + }) + + if topN > len(metrics) { + topN = len(metrics) + } + return metrics[:topN], nil +} + +// CrossRepoQuery queries across multiple codegraph databases. +// Useful for finding relationships between hawk, eyrie, tok, yaad, etc. +func CrossRepoQuery(repos []string, query string, limit int) (map[string][]Node, error) { + results := make(map[string][]Node) + + for _, repoRoot := range repos { + cg, err := Open(repoRoot) + if err != nil { + continue // Skip repos without codegraph + } + + nodes, err := cg.Search(query, limit) + cg.Close() + if err != nil || len(nodes) == 0 { + continue + } + + results[repoRoot] = nodes + } + + return results, nil +} + +// CrossRepoImpact finds the impact of changing a symbol across multiple repos. +// If a symbol in hawk calls a symbol in eyrie, this traces that cross-repo dependency. +func CrossRepoImpact(repos []string, symbol string, maxDepth int) (map[string]*ImpactResult, error) { + results := make(map[string]*ImpactResult) + + for _, repoRoot := range repos { + cg, err := Open(repoRoot) + if err != nil { + continue + } + + // Search for the symbol + nodes, err := cg.Search(symbol, 5) + if err != nil || len(nodes) == 0 { + cg.Close() + continue + } + + // Get impact for each matching symbol + for _, n := range nodes { + impact, err := cg.ImpactAnalysis(n.ID, maxDepth) + if err != nil { + continue + } + results[repoRoot+":"+n.Name] = impact + } + + cg.Close() + } + + return results, nil +} + +// FindCrossRepoCalls finds function calls that cross repo boundaries. +// For example, hawk calling eyrie functions. +func FindCrossRepoCalls(repos []string) ([]CrossRepoCall, error) { + type repoSymbol struct { + repo string + node Node + } + + // Build a map of all symbols across repos + allSymbols := make(map[string][]repoSymbol) // name -> [{repo, node}] + repoNodes := make(map[string]map[string]bool) + + for _, repoRoot := range repos { + cg, err := Open(repoRoot) + if err != nil { + continue + } + + repoNodes[repoRoot] = make(map[string]bool) + + // Get all nodes + rows, err := cg.db.QueryContext(context.Background(), "SELECT id, kind, name, qualified_name, file_path, language, start_line, end_line, signature, docstring, visibility, is_exported FROM nodes") + if err != nil { + cg.Close() + continue + } + + nodes, _ := scanNodes(rows) + rows.Close() + + for _, n := range nodes { + allSymbols[n.Name] = append(allSymbols[n.Name], repoSymbol{repoRoot, n}) + repoNodes[repoRoot][n.ID] = true + } + + cg.Close() + } + + // Find calls that reference symbols in other repos + var crossCalls []CrossRepoCall + + for _, repoRoot := range repos { + cg, err := Open(repoRoot) + if err != nil { + continue + } + + // Get unresolved refs (calls to symbols not in this repo) + rows, err := cg.db.QueryContext(context.Background(), "SELECT from_node_id, reference_name, file_path, line FROM unresolved_refs") + if err != nil { + cg.Close() + continue + } + + for rows.Next() { + var fromID, refName, filePath string + var line int + rows.Scan(&fromID, &refName, &filePath, &line) + + // Check if this reference exists in another repo + for _, target := range allSymbols[refName] { + if target.repo != repoRoot { + crossCalls = append(crossCalls, CrossRepoCall{ + FromRepo: repoRoot, + ToRepo: target.repo, + Symbol: refName, + File: filePath, + Line: line, + Target: target.node, + }) + } + } + } + + rows.Close() + cg.Close() + } + + return crossCalls, nil +} + +// CrossRepoCall represents a function call that crosses repo boundaries. +type CrossRepoCall struct { + FromRepo string `json:"from_repo"` + ToRepo string `json:"to_repo"` + Symbol string `json:"symbol"` + File string `json:"file"` + Line int `json:"line"` + Target Node `json:"target"` +} From 92def17d76fe8fcabcaafe150de6d8bf8663bc2e Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 20:34:41 +0530 Subject: [PATCH 17/20] test(fingerprint): split project_test.go into detect and conventions test files --- .../fingerprint/project_conventions_test.go | 145 ++++ .../fingerprint/project_detect_test.go | 583 ++++++++++++++ internal/feature/fingerprint/project_test.go | 713 ------------------ 3 files changed, 728 insertions(+), 713 deletions(-) create mode 100644 internal/feature/fingerprint/project_conventions_test.go create mode 100644 internal/feature/fingerprint/project_detect_test.go diff --git a/internal/feature/fingerprint/project_conventions_test.go b/internal/feature/fingerprint/project_conventions_test.go new file mode 100644 index 00000000..a75a75f8 --- /dev/null +++ b/internal/feature/fingerprint/project_conventions_test.go @@ -0,0 +1,145 @@ +package fingerprint + +import ( + "path/filepath" + "strings" + "testing" +) + +func TestDetectConventions_EditorConfig(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, ".editorconfig"), `root = true + +[*] +indent_style = tab +indent_size = 4 +`) + + convs := detectConventions(dir, "Go") + + found := false + for _, c := range convs { + if c.Name == "indentation" { + found = true + if !strings.Contains(c.Description, "Tab") { + t.Errorf("expected tab indentation, got %q", c.Description) + } + if c.Confidence != 1.0 { + t.Errorf("expected confidence 1.0, got %f", c.Confidence) + } + } + } + if !found { + t.Error("expected indentation convention to be detected from .editorconfig") + } +} + +func TestDetectConventions_SpacesEditorConfig(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, ".editorconfig"), `root = true + +[*] +indent_style = space +indent_size = 2 +`) + + convs := detectConventions(dir, "TypeScript") + + found := false + for _, c := range convs { + if c.Name == "indentation" { + found = true + if !strings.Contains(c.Description, "2-space") { + t.Errorf("expected 2-space indentation, got %q", c.Description) + } + } + } + if !found { + t.Error("expected indentation convention from .editorconfig") + } +} + +func TestDetectConventions_GoNaming(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "main.go"), "package main\n\nfunc main() {}\n") + + convs := detectConventions(dir, "Go") + + found := false + for _, c := range convs { + if c.Name == "naming" { + found = true + if !strings.Contains(c.Description, "camelCase") { + t.Errorf("expected camelCase/PascalCase for Go, got %q", c.Description) + } + } + } + if !found { + t.Error("expected naming convention for Go") + } +} + +func TestDetectConventions_GoErrorWrapping(t *testing.T) { + dir := t.TempDir() + content := `package main + +import "fmt" + +func foo() error { + err := bar() + if err != nil { + return fmt.Errorf("foo: %w", err) + } + err2 := baz() + if err2 != nil { + return fmt.Errorf("baz failed: %w", err2) + } + return nil +} +` + writeTestFile(t, filepath.Join(dir, "main.go"), content) + + convs := detectConventions(dir, "Go") + + found := false + for _, c := range convs { + if c.Name == "error-handling" { + found = true + if !strings.Contains(c.Description, "wrapping") { + t.Errorf("expected error wrapping convention, got %q", c.Description) + } + } + } + if !found { + t.Error("expected error-handling convention to be detected") + } +} + +func TestDetectConventions_PythonNaming(t *testing.T) { + dir := t.TempDir() + content := `def get_user_name(): + pass + +def calculate_total_price(): + pass + +def handle_request_error(): + pass +` + writeTestFile(t, filepath.Join(dir, "app.py"), content) + + convs := detectConventions(dir, "Python") + + found := false + for _, c := range convs { + if c.Name == "naming" { + found = true + if !strings.Contains(c.Description, "snake_case") { + t.Errorf("expected snake_case for Python, got %q", c.Description) + } + } + } + if !found { + t.Error("expected naming convention for Python") + } +} diff --git a/internal/feature/fingerprint/project_detect_test.go b/internal/feature/fingerprint/project_detect_test.go new file mode 100644 index 00000000..a232d3dd --- /dev/null +++ b/internal/feature/fingerprint/project_detect_test.go @@ -0,0 +1,583 @@ +package fingerprint + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestDetectLanguages_Extensions(t *testing.T) { + dir := t.TempDir() + + // Create files with different extensions. + writeTestFile(t, filepath.Join(dir, "main.go"), "package main\n\nfunc main() {}\n") + writeTestFile(t, filepath.Join(dir, "util.go"), "package main\n\nfunc util() {}\n") + writeTestFile(t, filepath.Join(dir, "handler.go"), "package main\n\nfunc handler() {}\n") + writeTestFile(t, filepath.Join(dir, "app.ts"), "const x = 1;\n") + writeTestFile(t, filepath.Join(dir, "style.css"), "body { color: red; }\n") + + langs := detectLanguages(dir) + + if len(langs) == 0 { + t.Fatal("expected at least one language") + } + + // Go should be first (3 files). + if langs[0].Name != "Go" { + t.Errorf("expected Go as primary language, got %q", langs[0].Name) + } + if langs[0].FileCount != 3 { + t.Errorf("expected 3 Go files, got %d", langs[0].FileCount) + } + + // Check percentage. + expectedPct := 3.0 / 5.0 * 100 + if langs[0].Percentage < expectedPct-0.1 || langs[0].Percentage > expectedPct+0.1 { + t.Errorf("expected ~%.1f%% for Go, got %.1f%%", expectedPct, langs[0].Percentage) + } + + // Check that all languages are detected. + langMap := make(map[string]int) + for _, l := range langs { + langMap[l.Name] = l.FileCount + } + if langMap["TypeScript"] != 1 { + t.Errorf("expected 1 TypeScript file, got %d", langMap["TypeScript"]) + } + if langMap["CSS"] != 1 { + t.Errorf("expected 1 CSS file, got %d", langMap["CSS"]) + } +} + +func TestDetectLanguages_SortedByFileCount(t *testing.T) { + dir := t.TempDir() + + // Create 5 Python files, 3 JS files, 1 Go file. + for i := 0; i < 5; i++ { + writeTestFile(t, filepath.Join(dir, strings.Repeat("a", i+1)+".py"), "# python\n") + } + for i := 0; i < 3; i++ { + writeTestFile(t, filepath.Join(dir, strings.Repeat("b", i+1)+".js"), "// js\n") + } + writeTestFile(t, filepath.Join(dir, "main.go"), "package main\n") + + langs := detectLanguages(dir) + + if len(langs) < 3 { + t.Fatalf("expected at least 3 languages, got %d", len(langs)) + } + + // Should be sorted: Python (5), JavaScript (3), Go (1). + if langs[0].Name != "Python" { + t.Errorf("expected Python first, got %q", langs[0].Name) + } + if langs[1].Name != "JavaScript" { + t.Errorf("expected JavaScript second, got %q", langs[1].Name) + } + if langs[2].Name != "Go" { + t.Errorf("expected Go third, got %q", langs[2].Name) + } +} + +func TestDetectFramework_GoMod(t *testing.T) { + tests := []struct { + name string + gomod string + expected string + }{ + { + name: "chi", + gomod: `module example.com/app + +go 1.21 + +require ( + github.com/go-chi/chi/v5 v5.0.10 +) +`, + expected: "chi", + }, + { + name: "gin", + gomod: `module example.com/app + +go 1.21 + +require ( + github.com/gin-gonic/gin v1.9.1 +) +`, + expected: "gin", + }, + { + name: "echo", + gomod: `module example.com/app + +go 1.21 + +require ( + github.com/labstack/echo/v4 v4.11.0 +) +`, + expected: "echo", + }, + { + name: "fiber", + gomod: `module example.com/app + +go 1.21 + +require ( + github.com/gofiber/fiber/v2 v2.50.0 +) +`, + expected: "fiber", + }, + { + name: "gorilla", + gomod: `module example.com/app + +go 1.21 + +require ( + github.com/gorilla/mux v1.8.0 +) +`, + expected: "gorilla", + }, + { + name: "no framework", + gomod: "module example.com/app\n\ngo 1.21\n", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "go.mod"), tt.gomod) + + result := detectFramework(dir, "Go") + if result != tt.expected { + t.Errorf("expected framework %q, got %q", tt.expected, result) + } + }) + } +} + +func TestDetectFramework_PackageJSON(t *testing.T) { + tests := []struct { + name string + pkg string + expected string + }{ + { + name: "next.js", + pkg: `{"dependencies": {"next": "^13.0.0", "react": "^18.0.0"}}`, + expected: "next.js", + }, + { + name: "express", + pkg: `{"dependencies": {"express": "^4.18.0"}}`, + expected: "express", + }, + { + name: "vue", + pkg: `{"dependencies": {"vue": "^3.0.0"}}`, + expected: "vue", + }, + { + name: "angular", + pkg: `{"dependencies": {"@angular/core": "^16.0.0"}}`, + expected: "angular", + }, + { + name: "react only", + pkg: `{"dependencies": {"react": "^18.0.0"}}`, + expected: "react", + }, + { + name: "no framework", + pkg: `{"dependencies": {"lodash": "^4.0.0"}}`, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "package.json"), tt.pkg) + + result := detectFramework(dir, "JavaScript") + if result != tt.expected { + t.Errorf("expected framework %q, got %q", tt.expected, result) + } + }) + } +} + +func TestDetectFramework_Python(t *testing.T) { + tests := []struct { + name string + requirements string + expected string + }{ + { + name: "django", + requirements: "django==4.2\ncelery==5.3\n", + expected: "django", + }, + { + name: "flask", + requirements: "flask==2.3\n", + expected: "flask", + }, + { + name: "fastapi", + requirements: "fastapi==0.100.0\nuvicorn==0.23.0\n", + expected: "fastapi", + }, + { + name: "no framework", + requirements: "requests==2.31\nnumpy==1.25\n", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "requirements.txt"), tt.requirements) + + result := detectFramework(dir, "Python") + if result != tt.expected { + t.Errorf("expected framework %q, got %q", tt.expected, result) + } + }) + } +} + +func TestDetectFramework_Rust(t *testing.T) { + tests := []struct { + name string + cargo string + expected string + }{ + { + name: "actix", + cargo: `[package] +name = "myapp" +version = "0.1.0" + +[dependencies] +actix-web = "4" +`, + expected: "actix", + }, + { + name: "axum", + cargo: `[package] +name = "myapp" +version = "0.1.0" + +[dependencies] +axum = "0.6" +tokio = { version = "1", features = ["full"] } +`, + expected: "axum", + }, + { + name: "rocket", + cargo: `[package] +name = "myapp" +version = "0.1.0" + +[dependencies] +rocket = "0.5" +`, + expected: "rocket", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "Cargo.toml"), tt.cargo) + + result := detectFramework(dir, "Rust") + if result != tt.expected { + t.Errorf("expected framework %q, got %q", tt.expected, result) + } + }) + } +} + +func TestDetectCISystem(t *testing.T) { + tests := []struct { + name string + setup func(dir string) + expected string + }{ + { + name: "github-actions", + setup: func(dir string) { + os.MkdirAll(filepath.Join(dir, ".github", "workflows"), 0o755) + }, + expected: "github-actions", + }, + { + name: "gitlab-ci", + setup: func(dir string) { + writeTestFile2(filepath.Join(dir, ".gitlab-ci.yml"), "stages:\n - test\n") + }, + expected: "gitlab-ci", + }, + { + name: "circleci", + setup: func(dir string) { + os.MkdirAll(filepath.Join(dir, ".circleci"), 0o755) + }, + expected: "circleci", + }, + { + name: "jenkins", + setup: func(dir string) { + writeTestFile2(filepath.Join(dir, "Jenkinsfile"), "pipeline {}\n") + }, + expected: "jenkins", + }, + { + name: "travis-ci", + setup: func(dir string) { + writeTestFile2(filepath.Join(dir, ".travis.yml"), "language: go\n") + }, + expected: "travis-ci", + }, + { + name: "no CI", + setup: func(dir string) {}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + tt.setup(dir) + + result := detectCISystem(dir) + if result != tt.expected { + t.Errorf("expected CI %q, got %q", tt.expected, result) + } + }) + } +} + +func TestDetectBuildSystem(t *testing.T) { + tests := []struct { + name string + file string + content string + expected string + }{ + {"go modules", "go.mod", "module example.com/app\n\ngo 1.21\n", "go modules"}, + {"npm", "package.json", `{"name": "app"}`, "npm"}, + {"cargo", "Cargo.toml", "[package]\nname = \"app\"\n", "cargo"}, + {"maven", "pom.xml", "", "maven"}, + {"gradle", "build.gradle", "plugins { id 'java' }", "gradle"}, + {"cmake", "CMakeLists.txt", "cmake_minimum_required(VERSION 3.10)", "cmake"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, tt.file), tt.content) + + result := detectBuildSystem(dir) + if result != tt.expected { + t.Errorf("expected build system %q, got %q", tt.expected, result) + } + }) + } +} + +func TestDetectTestFramework(t *testing.T) { + t.Run("go test", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "go.mod"), "module example.com/app\n\ngo 1.21\n") + writeTestFile(t, filepath.Join(dir, "main_test.go"), "package main\n\nimport \"testing\"\n\nfunc TestFoo(t *testing.T) {}\n") + + result := detectTestFramework(dir, "Go") + if result != "go test" { + t.Errorf("expected 'go test', got %q", result) + } + }) + + t.Run("go test + testify", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "go.mod"), "module example.com/app\n\ngo 1.21\n\nrequire (\n\tgithub.com/stretchr/testify v1.8.0\n)\n") + writeTestFile(t, filepath.Join(dir, "main_test.go"), "package main\n\nimport \"testing\"\n\nfunc TestFoo(t *testing.T) {}\n") + + result := detectTestFramework(dir, "Go") + if result != "go test + testify" { + t.Errorf("expected 'go test + testify', got %q", result) + } + }) + + t.Run("jest", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "package.json"), `{"devDependencies": {"jest": "^29.0.0"}}`) + + result := detectTestFramework(dir, "JavaScript") + if result != "jest" { + t.Errorf("expected 'jest', got %q", result) + } + }) + + t.Run("vitest", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "package.json"), `{"devDependencies": {"vitest": "^0.34.0"}}`) + + result := detectTestFramework(dir, "TypeScript") + if result != "vitest" { + t.Errorf("expected 'vitest', got %q", result) + } + }) + + t.Run("pytest", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "requirements.txt"), "pytest==7.4.0\nflask==2.3.0\n") + + result := detectTestFramework(dir, "Python") + if result != "pytest" { + t.Errorf("expected 'pytest', got %q", result) + } + }) + + t.Run("pytest from conftest", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "conftest.py"), "import pytest\n") + + result := detectTestFramework(dir, "Python") + if result != "pytest" { + t.Errorf("expected 'pytest', got %q", result) + } + }) + + t.Run("cargo test", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "Cargo.toml"), "[package]\nname = \"app\"\nversion = \"0.1.0\"\n") + + result := detectTestFramework(dir, "Rust") + if result != "cargo test" { + t.Errorf("expected 'cargo test', got %q", result) + } + }) + + t.Run("rspec", func(t *testing.T) { + dir := t.TempDir() + os.Mkdir(filepath.Join(dir, "spec"), 0o755) + writeTestFile(t, filepath.Join(dir, "Gemfile"), "gem 'rails'\ngem 'rspec'\n") + + result := detectTestFramework(dir, "Ruby") + if result != "rspec" { + t.Errorf("expected 'rspec', got %q", result) + } + }) +} + +func TestDetectDocker(t *testing.T) { + t.Run("with Dockerfile", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "Dockerfile"), "FROM golang:1.21\n") + + if !detectDocker(dir) { + t.Error("expected Docker=true with Dockerfile") + } + }) + + t.Run("with docker-compose", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "docker-compose.yml"), "version: '3'\n") + + if !detectDocker(dir) { + t.Error("expected Docker=true with docker-compose.yml") + } + }) + + t.Run("with compose.yaml", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "compose.yaml"), "services:\n") + + if !detectDocker(dir) { + t.Error("expected Docker=true with compose.yaml") + } + }) + + t.Run("no docker", func(t *testing.T) { + dir := t.TempDir() + + if detectDocker(dir) { + t.Error("expected Docker=false for empty dir") + } + }) +} + +func TestDetectMonorepo_MultipleGoMod(t *testing.T) { + dir := t.TempDir() + os.MkdirAll(filepath.Join(dir, "pkg1"), 0o755) + os.MkdirAll(filepath.Join(dir, "pkg2"), 0o755) + writeTestFile(t, filepath.Join(dir, "pkg1", "go.mod"), "module example.com/pkg1\n") + writeTestFile(t, filepath.Join(dir, "pkg2", "go.mod"), "module example.com/pkg2\n") + + if !detectMonorepo(dir) { + t.Error("expected Monorepo=true with multiple go.mod files") + } +} + +func TestDetectMonorepo_GoWork(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "go.work"), "go 1.21\n\nuse (\n\t./pkg1\n\t./pkg2\n)\n") + + if !detectMonorepo(dir) { + t.Error("expected Monorepo=true with go.work") + } +} + +func TestDetectMonorepo_PackagesDir(t *testing.T) { + dir := t.TempDir() + os.MkdirAll(filepath.Join(dir, "packages"), 0o755) + + if !detectMonorepo(dir) { + t.Error("expected Monorepo=true with packages/ directory") + } +} + +func TestDetectMonorepo_Workspaces(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "package.json"), `{"name": "monorepo", "workspaces": ["packages/*"]}`) + + if !detectMonorepo(dir) { + t.Error("expected Monorepo=true with workspaces in package.json") + } +} + +func TestDetectMonorepo_LernaJSON(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "lerna.json"), `{"version": "1.0.0"}`) + + if !detectMonorepo(dir) { + t.Error("expected Monorepo=true with lerna.json") + } +} + +func TestDetectMonorepo_NotMonorepo(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, filepath.Join(dir, "go.mod"), "module example.com/app\n") + writeTestFile(t, filepath.Join(dir, "main.go"), "package main\n") + + if detectMonorepo(dir) { + t.Error("expected Monorepo=false for simple project") + } +} diff --git a/internal/feature/fingerprint/project_test.go b/internal/feature/fingerprint/project_test.go index cd236655..1dddfa9c 100644 --- a/internal/feature/fingerprint/project_test.go +++ b/internal/feature/fingerprint/project_test.go @@ -41,719 +41,6 @@ func TestScan_InvalidPath(t *testing.T) { } } -func TestDetectLanguages_Extensions(t *testing.T) { - dir := t.TempDir() - - // Create files with different extensions. - writeTestFile(t, filepath.Join(dir, "main.go"), "package main\n\nfunc main() {}\n") - writeTestFile(t, filepath.Join(dir, "util.go"), "package main\n\nfunc util() {}\n") - writeTestFile(t, filepath.Join(dir, "handler.go"), "package main\n\nfunc handler() {}\n") - writeTestFile(t, filepath.Join(dir, "app.ts"), "const x = 1;\n") - writeTestFile(t, filepath.Join(dir, "style.css"), "body { color: red; }\n") - - langs := detectLanguages(dir) - - if len(langs) == 0 { - t.Fatal("expected at least one language") - } - - // Go should be first (3 files). - if langs[0].Name != "Go" { - t.Errorf("expected Go as primary language, got %q", langs[0].Name) - } - if langs[0].FileCount != 3 { - t.Errorf("expected 3 Go files, got %d", langs[0].FileCount) - } - - // Check percentage. - expectedPct := 3.0 / 5.0 * 100 - if langs[0].Percentage < expectedPct-0.1 || langs[0].Percentage > expectedPct+0.1 { - t.Errorf("expected ~%.1f%% for Go, got %.1f%%", expectedPct, langs[0].Percentage) - } - - // Check that all languages are detected. - langMap := make(map[string]int) - for _, l := range langs { - langMap[l.Name] = l.FileCount - } - if langMap["TypeScript"] != 1 { - t.Errorf("expected 1 TypeScript file, got %d", langMap["TypeScript"]) - } - if langMap["CSS"] != 1 { - t.Errorf("expected 1 CSS file, got %d", langMap["CSS"]) - } -} - -func TestDetectLanguages_SortedByFileCount(t *testing.T) { - dir := t.TempDir() - - // Create 5 Python files, 3 JS files, 1 Go file. - for i := 0; i < 5; i++ { - writeTestFile(t, filepath.Join(dir, strings.Repeat("a", i+1)+".py"), "# python\n") - } - for i := 0; i < 3; i++ { - writeTestFile(t, filepath.Join(dir, strings.Repeat("b", i+1)+".js"), "// js\n") - } - writeTestFile(t, filepath.Join(dir, "main.go"), "package main\n") - - langs := detectLanguages(dir) - - if len(langs) < 3 { - t.Fatalf("expected at least 3 languages, got %d", len(langs)) - } - - // Should be sorted: Python (5), JavaScript (3), Go (1). - if langs[0].Name != "Python" { - t.Errorf("expected Python first, got %q", langs[0].Name) - } - if langs[1].Name != "JavaScript" { - t.Errorf("expected JavaScript second, got %q", langs[1].Name) - } - if langs[2].Name != "Go" { - t.Errorf("expected Go third, got %q", langs[2].Name) - } -} - -func TestDetectFramework_GoMod(t *testing.T) { - tests := []struct { - name string - gomod string - expected string - }{ - { - name: "chi", - gomod: `module example.com/app - -go 1.21 - -require ( - github.com/go-chi/chi/v5 v5.0.10 -) -`, - expected: "chi", - }, - { - name: "gin", - gomod: `module example.com/app - -go 1.21 - -require ( - github.com/gin-gonic/gin v1.9.1 -) -`, - expected: "gin", - }, - { - name: "echo", - gomod: `module example.com/app - -go 1.21 - -require ( - github.com/labstack/echo/v4 v4.11.0 -) -`, - expected: "echo", - }, - { - name: "fiber", - gomod: `module example.com/app - -go 1.21 - -require ( - github.com/gofiber/fiber/v2 v2.50.0 -) -`, - expected: "fiber", - }, - { - name: "gorilla", - gomod: `module example.com/app - -go 1.21 - -require ( - github.com/gorilla/mux v1.8.0 -) -`, - expected: "gorilla", - }, - { - name: "no framework", - gomod: "module example.com/app\n\ngo 1.21\n", - expected: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "go.mod"), tt.gomod) - - result := detectFramework(dir, "Go") - if result != tt.expected { - t.Errorf("expected framework %q, got %q", tt.expected, result) - } - }) - } -} - -func TestDetectFramework_PackageJSON(t *testing.T) { - tests := []struct { - name string - pkg string - expected string - }{ - { - name: "next.js", - pkg: `{"dependencies": {"next": "^13.0.0", "react": "^18.0.0"}}`, - expected: "next.js", - }, - { - name: "express", - pkg: `{"dependencies": {"express": "^4.18.0"}}`, - expected: "express", - }, - { - name: "vue", - pkg: `{"dependencies": {"vue": "^3.0.0"}}`, - expected: "vue", - }, - { - name: "angular", - pkg: `{"dependencies": {"@angular/core": "^16.0.0"}}`, - expected: "angular", - }, - { - name: "react only", - pkg: `{"dependencies": {"react": "^18.0.0"}}`, - expected: "react", - }, - { - name: "no framework", - pkg: `{"dependencies": {"lodash": "^4.0.0"}}`, - expected: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "package.json"), tt.pkg) - - result := detectFramework(dir, "JavaScript") - if result != tt.expected { - t.Errorf("expected framework %q, got %q", tt.expected, result) - } - }) - } -} - -func TestDetectFramework_Python(t *testing.T) { - tests := []struct { - name string - requirements string - expected string - }{ - { - name: "django", - requirements: "django==4.2\ncelery==5.3\n", - expected: "django", - }, - { - name: "flask", - requirements: "flask==2.3\n", - expected: "flask", - }, - { - name: "fastapi", - requirements: "fastapi==0.100.0\nuvicorn==0.23.0\n", - expected: "fastapi", - }, - { - name: "no framework", - requirements: "requests==2.31\nnumpy==1.25\n", - expected: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "requirements.txt"), tt.requirements) - - result := detectFramework(dir, "Python") - if result != tt.expected { - t.Errorf("expected framework %q, got %q", tt.expected, result) - } - }) - } -} - -func TestDetectFramework_Rust(t *testing.T) { - tests := []struct { - name string - cargo string - expected string - }{ - { - name: "actix", - cargo: `[package] -name = "myapp" -version = "0.1.0" - -[dependencies] -actix-web = "4" -`, - expected: "actix", - }, - { - name: "axum", - cargo: `[package] -name = "myapp" -version = "0.1.0" - -[dependencies] -axum = "0.6" -tokio = { version = "1", features = ["full"] } -`, - expected: "axum", - }, - { - name: "rocket", - cargo: `[package] -name = "myapp" -version = "0.1.0" - -[dependencies] -rocket = "0.5" -`, - expected: "rocket", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "Cargo.toml"), tt.cargo) - - result := detectFramework(dir, "Rust") - if result != tt.expected { - t.Errorf("expected framework %q, got %q", tt.expected, result) - } - }) - } -} - -func TestDetectCISystem(t *testing.T) { - tests := []struct { - name string - setup func(dir string) - expected string - }{ - { - name: "github-actions", - setup: func(dir string) { - os.MkdirAll(filepath.Join(dir, ".github", "workflows"), 0o755) - }, - expected: "github-actions", - }, - { - name: "gitlab-ci", - setup: func(dir string) { - writeTestFile2(filepath.Join(dir, ".gitlab-ci.yml"), "stages:\n - test\n") - }, - expected: "gitlab-ci", - }, - { - name: "circleci", - setup: func(dir string) { - os.MkdirAll(filepath.Join(dir, ".circleci"), 0o755) - }, - expected: "circleci", - }, - { - name: "jenkins", - setup: func(dir string) { - writeTestFile2(filepath.Join(dir, "Jenkinsfile"), "pipeline {}\n") - }, - expected: "jenkins", - }, - { - name: "travis-ci", - setup: func(dir string) { - writeTestFile2(filepath.Join(dir, ".travis.yml"), "language: go\n") - }, - expected: "travis-ci", - }, - { - name: "no CI", - setup: func(dir string) {}, - expected: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dir := t.TempDir() - tt.setup(dir) - - result := detectCISystem(dir) - if result != tt.expected { - t.Errorf("expected CI %q, got %q", tt.expected, result) - } - }) - } -} - -func TestDetectBuildSystem(t *testing.T) { - tests := []struct { - name string - file string - content string - expected string - }{ - {"go modules", "go.mod", "module example.com/app\n\ngo 1.21\n", "go modules"}, - {"npm", "package.json", `{"name": "app"}`, "npm"}, - {"cargo", "Cargo.toml", "[package]\nname = \"app\"\n", "cargo"}, - {"maven", "pom.xml", "", "maven"}, - {"gradle", "build.gradle", "plugins { id 'java' }", "gradle"}, - {"cmake", "CMakeLists.txt", "cmake_minimum_required(VERSION 3.10)", "cmake"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, tt.file), tt.content) - - result := detectBuildSystem(dir) - if result != tt.expected { - t.Errorf("expected build system %q, got %q", tt.expected, result) - } - }) - } -} - -func TestDetectTestFramework(t *testing.T) { - t.Run("go test", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "go.mod"), "module example.com/app\n\ngo 1.21\n") - writeTestFile(t, filepath.Join(dir, "main_test.go"), "package main\n\nimport \"testing\"\n\nfunc TestFoo(t *testing.T) {}\n") - - result := detectTestFramework(dir, "Go") - if result != "go test" { - t.Errorf("expected 'go test', got %q", result) - } - }) - - t.Run("go test + testify", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "go.mod"), "module example.com/app\n\ngo 1.21\n\nrequire (\n\tgithub.com/stretchr/testify v1.8.0\n)\n") - writeTestFile(t, filepath.Join(dir, "main_test.go"), "package main\n\nimport \"testing\"\n\nfunc TestFoo(t *testing.T) {}\n") - - result := detectTestFramework(dir, "Go") - if result != "go test + testify" { - t.Errorf("expected 'go test + testify', got %q", result) - } - }) - - t.Run("jest", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "package.json"), `{"devDependencies": {"jest": "^29.0.0"}}`) - - result := detectTestFramework(dir, "JavaScript") - if result != "jest" { - t.Errorf("expected 'jest', got %q", result) - } - }) - - t.Run("vitest", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "package.json"), `{"devDependencies": {"vitest": "^0.34.0"}}`) - - result := detectTestFramework(dir, "TypeScript") - if result != "vitest" { - t.Errorf("expected 'vitest', got %q", result) - } - }) - - t.Run("pytest", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "requirements.txt"), "pytest==7.4.0\nflask==2.3.0\n") - - result := detectTestFramework(dir, "Python") - if result != "pytest" { - t.Errorf("expected 'pytest', got %q", result) - } - }) - - t.Run("pytest from conftest", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "conftest.py"), "import pytest\n") - - result := detectTestFramework(dir, "Python") - if result != "pytest" { - t.Errorf("expected 'pytest', got %q", result) - } - }) - - t.Run("cargo test", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "Cargo.toml"), "[package]\nname = \"app\"\nversion = \"0.1.0\"\n") - - result := detectTestFramework(dir, "Rust") - if result != "cargo test" { - t.Errorf("expected 'cargo test', got %q", result) - } - }) - - t.Run("rspec", func(t *testing.T) { - dir := t.TempDir() - os.Mkdir(filepath.Join(dir, "spec"), 0o755) - writeTestFile(t, filepath.Join(dir, "Gemfile"), "gem 'rails'\ngem 'rspec'\n") - - result := detectTestFramework(dir, "Ruby") - if result != "rspec" { - t.Errorf("expected 'rspec', got %q", result) - } - }) -} - -func TestDetectDocker(t *testing.T) { - t.Run("with Dockerfile", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "Dockerfile"), "FROM golang:1.21\n") - - if !detectDocker(dir) { - t.Error("expected Docker=true with Dockerfile") - } - }) - - t.Run("with docker-compose", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "docker-compose.yml"), "version: '3'\n") - - if !detectDocker(dir) { - t.Error("expected Docker=true with docker-compose.yml") - } - }) - - t.Run("with compose.yaml", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "compose.yaml"), "services:\n") - - if !detectDocker(dir) { - t.Error("expected Docker=true with compose.yaml") - } - }) - - t.Run("no docker", func(t *testing.T) { - dir := t.TempDir() - - if detectDocker(dir) { - t.Error("expected Docker=false for empty dir") - } - }) -} - -func TestDetectMonorepo_MultipleGoMod(t *testing.T) { - dir := t.TempDir() - os.MkdirAll(filepath.Join(dir, "pkg1"), 0o755) - os.MkdirAll(filepath.Join(dir, "pkg2"), 0o755) - writeTestFile(t, filepath.Join(dir, "pkg1", "go.mod"), "module example.com/pkg1\n") - writeTestFile(t, filepath.Join(dir, "pkg2", "go.mod"), "module example.com/pkg2\n") - - if !detectMonorepo(dir) { - t.Error("expected Monorepo=true with multiple go.mod files") - } -} - -func TestDetectMonorepo_GoWork(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "go.work"), "go 1.21\n\nuse (\n\t./pkg1\n\t./pkg2\n)\n") - - if !detectMonorepo(dir) { - t.Error("expected Monorepo=true with go.work") - } -} - -func TestDetectMonorepo_PackagesDir(t *testing.T) { - dir := t.TempDir() - os.MkdirAll(filepath.Join(dir, "packages"), 0o755) - - if !detectMonorepo(dir) { - t.Error("expected Monorepo=true with packages/ directory") - } -} - -func TestDetectMonorepo_Workspaces(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "package.json"), `{"name": "monorepo", "workspaces": ["packages/*"]}`) - - if !detectMonorepo(dir) { - t.Error("expected Monorepo=true with workspaces in package.json") - } -} - -func TestDetectMonorepo_LernaJSON(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "lerna.json"), `{"version": "1.0.0"}`) - - if !detectMonorepo(dir) { - t.Error("expected Monorepo=true with lerna.json") - } -} - -func TestDetectMonorepo_NotMonorepo(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "go.mod"), "module example.com/app\n") - writeTestFile(t, filepath.Join(dir, "main.go"), "package main\n") - - if detectMonorepo(dir) { - t.Error("expected Monorepo=false for simple project") - } -} - -func TestDetectConventions_EditorConfig(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, ".editorconfig"), `root = true - -[*] -indent_style = tab -indent_size = 4 -`) - - convs := detectConventions(dir, "Go") - - found := false - for _, c := range convs { - if c.Name == "indentation" { - found = true - if !strings.Contains(c.Description, "Tab") { - t.Errorf("expected tab indentation, got %q", c.Description) - } - if c.Confidence != 1.0 { - t.Errorf("expected confidence 1.0, got %f", c.Confidence) - } - } - } - if !found { - t.Error("expected indentation convention to be detected from .editorconfig") - } -} - -func TestDetectConventions_SpacesEditorConfig(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, ".editorconfig"), `root = true - -[*] -indent_style = space -indent_size = 2 -`) - - convs := detectConventions(dir, "TypeScript") - - found := false - for _, c := range convs { - if c.Name == "indentation" { - found = true - if !strings.Contains(c.Description, "2-space") { - t.Errorf("expected 2-space indentation, got %q", c.Description) - } - } - } - if !found { - t.Error("expected indentation convention from .editorconfig") - } -} - -func TestDetectConventions_GoNaming(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, filepath.Join(dir, "main.go"), "package main\n\nfunc main() {}\n") - - convs := detectConventions(dir, "Go") - - found := false - for _, c := range convs { - if c.Name == "naming" { - found = true - if !strings.Contains(c.Description, "camelCase") { - t.Errorf("expected camelCase/PascalCase for Go, got %q", c.Description) - } - } - } - if !found { - t.Error("expected naming convention for Go") - } -} - -func TestDetectConventions_GoErrorWrapping(t *testing.T) { - dir := t.TempDir() - content := `package main - -import "fmt" - -func foo() error { - err := bar() - if err != nil { - return fmt.Errorf("foo: %w", err) - } - err2 := baz() - if err2 != nil { - return fmt.Errorf("baz failed: %w", err2) - } - return nil -} -` - writeTestFile(t, filepath.Join(dir, "main.go"), content) - - convs := detectConventions(dir, "Go") - - found := false - for _, c := range convs { - if c.Name == "error-handling" { - found = true - if !strings.Contains(c.Description, "wrapping") { - t.Errorf("expected error wrapping convention, got %q", c.Description) - } - } - } - if !found { - t.Error("expected error-handling convention to be detected") - } -} - -func TestDetectConventions_PythonNaming(t *testing.T) { - dir := t.TempDir() - content := `def get_user_name(): - pass - -def calculate_total_price(): - pass - -def handle_request_error(): - pass -` - writeTestFile(t, filepath.Join(dir, "app.py"), content) - - convs := detectConventions(dir, "Python") - - found := false - for _, c := range convs { - if c.Name == "naming" { - found = true - if !strings.Contains(c.Description, "snake_case") { - t.Errorf("expected snake_case for Python, got %q", c.Description) - } - } - } - if !found { - t.Error("expected naming convention for Python") - } -} - func TestGenerateRecommendations_Go(t *testing.T) { fp := &ProjectFingerprint{ Language: "Go", From cabafb8a7f50f214e98e59d3fc86bc4c58c6683a Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 20:38:04 +0530 Subject: [PATCH 18/20] test(agents): split persona registry/selection tests into persona_registry_test.go --- .../agents/persona_registry_test.go | 668 ++++++++++++++++++ internal/multiagent/agents/persona_test.go | 660 ----------------- 2 files changed, 668 insertions(+), 660 deletions(-) create mode 100644 internal/multiagent/agents/persona_registry_test.go diff --git a/internal/multiagent/agents/persona_registry_test.go b/internal/multiagent/agents/persona_registry_test.go new file mode 100644 index 00000000..052c86a0 --- /dev/null +++ b/internal/multiagent/agents/persona_registry_test.go @@ -0,0 +1,668 @@ +package agents + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestSelectPersona_Security(t *testing.T) { + r := NewPersonaRegistry(t.TempDir()) + r.Personas["security"] = &Persona{ + Name: "security", + Expertise: []string{"security"}, + } + r.Personas["tester"] = &Persona{ + Name: "tester", + Expertise: []string{"testing"}, + } + r.Personas["devops"] = &Persona{ + Name: "devops", + Expertise: []string{"devops"}, + } + + // Security task + p := r.SelectPersona("fix security vulnerability in auth handler") + if p == nil || p.Name != "security" { + t.Errorf("expected security persona, got %v", p) + } +} + +func TestSelectPersona_Testing(t *testing.T) { + r := NewPersonaRegistry(t.TempDir()) + r.Personas["security"] = &Persona{ + Name: "security", + Expertise: []string{"security"}, + } + r.Personas["tester"] = &Persona{ + Name: "tester", + Expertise: []string{"testing"}, + } + + // Testing task + p := r.SelectPersona("write unit tests for the parser") + if p == nil || p.Name != "tester" { + t.Errorf("expected tester persona, got %v", p) + } +} + +func TestSelectPersona_DevOps(t *testing.T) { + r := NewPersonaRegistry(t.TempDir()) + r.Personas["devops"] = &Persona{ + Name: "devops", + Expertise: []string{"devops"}, + } + r.Personas["backend"] = &Persona{ + Name: "backend", + Expertise: []string{"backend"}, + } + + // DevOps task + p := r.SelectPersona("deploy to kubernetes cluster") + if p == nil || p.Name != "devops" { + t.Errorf("expected devops persona, got %v", p) + } +} + +func TestSelectPersona_FallsBackToDefault(t *testing.T) { + r := NewPersonaRegistry(t.TempDir()) + r.Personas["default"] = &Persona{ + Name: "default", + Expertise: []string{}, + } + r.Personas["security"] = &Persona{ + Name: "security", + Expertise: []string{"security"}, + } + + // No keyword match + p := r.SelectPersona("do something random and unrelated") + if p == nil || p.Name != "default" { + t.Errorf("expected default persona as fallback, got %v", p) + } +} + +func TestSelectPersona_NoMatch(t *testing.T) { + r := NewPersonaRegistry(t.TempDir()) + r.Personas["security"] = &Persona{ + Name: "security", + Expertise: []string{"security"}, + } + + // No match and no default + p := r.SelectPersona("play some music") + if p != nil { + t.Errorf("expected nil when no match and no default, got %v", p) + } +} + +func TestBuildSystemPrompt_IncludesAllComponents(t *testing.T) { + p := &Persona{ + Name: "test", + SystemPrompt: "You are a test assistant.", + Expertise: []string{"backend", "testing"}, + CommunicationStyle: "concise", + Rules: []string{"Rule one", "Rule two"}, + Examples: []PersonaExample{ + { + Input: "example input", + Output: "example output", + Context: "example context", + }, + }, + } + + result := BuildSystemPrompt(p, "This is a Go project using REST APIs.") + + // Should contain system prompt + if !strings.Contains(result, "You are a test assistant.") { + t.Error("should contain system prompt") + } + + // Should contain expertise + if !strings.Contains(result, "backend, testing") { + t.Error("should contain expertise") + } + + // Should contain communication style + if !strings.Contains(result, "brief and to the point") { + t.Error("should contain communication style for 'concise'") + } + + // Should contain rules + if !strings.Contains(result, "- Rule one") { + t.Error("should contain rules") + } + if !strings.Contains(result, "- Rule two") { + t.Error("should contain rule two") + } + + // Should contain examples + if !strings.Contains(result, "example input") { + t.Error("should contain example input") + } + if !strings.Contains(result, "example output") { + t.Error("should contain example output") + } + if !strings.Contains(result, "example context") { + t.Error("should contain example context") + } + + // Should contain project context + if !strings.Contains(result, "This is a Go project using REST APIs.") { + t.Error("should contain project context") + } +} + +func TestBuildSystemPrompt_EmptyPersona(t *testing.T) { + p := &Persona{Name: "empty"} + result := BuildSystemPrompt(p, "") + if result != "" { + t.Errorf("expected empty prompt for empty persona, got %q", result) + } +} + +func TestCreateGetDelete_Lifecycle(t *testing.T) { + dir := t.TempDir() + r := NewPersonaRegistry(dir) + + // Create + p := &Persona{ + Name: "lifecycle-test", + Description: "Test persona for lifecycle", + Model: "claude-sonnet-4-6", + Expertise: []string{"testing"}, + SystemPrompt: "You are a lifecycle test.", + } + if err := r.Create(p); err != nil { + t.Fatalf("Create failed: %v", err) + } + + // Verify file exists + path := filepath.Join(dir, "lifecycle-test.md") + if _, err := os.Stat(path); err != nil { + t.Fatalf("persona file not created: %v", err) + } + + // Get + got, err := r.Get("lifecycle-test") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if got.Name != "lifecycle-test" { + t.Errorf("expected name 'lifecycle-test', got %q", got.Name) + } + if got.Model != "claude-sonnet-4-6" { + t.Errorf("expected model 'claude-sonnet-4-6', got %q", got.Model) + } + + // Delete + if err := r.Delete("lifecycle-test"); err != nil { + t.Fatalf("Delete failed: %v", err) + } + + // Verify gone + if _, err := r.Get("lifecycle-test"); err == nil { + t.Error("expected error after delete, got nil") + } + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Error("persona file should be deleted from disk") + } +} + +func TestCreate_EmptyName(t *testing.T) { + dir := t.TempDir() + r := NewPersonaRegistry(dir) + + err := r.Create(&Persona{}) + if err == nil { + t.Error("expected error for empty name") + } +} + +func TestDelete_NotFound(t *testing.T) { + dir := t.TempDir() + r := NewPersonaRegistry(dir) + + err := r.Delete("nonexistent") + if err == nil { + t.Error("expected error deleting nonexistent persona") + } +} + +func TestList_ReturnsAllPersonas(t *testing.T) { + dir := t.TempDir() + r := NewPersonaRegistry(dir) + + r.Personas["alpha"] = &Persona{Name: "alpha"} + r.Personas["beta"] = &Persona{Name: "beta"} + r.Personas["gamma"] = &Persona{Name: "gamma"} + + list := r.List() + if len(list) != 3 { + t.Fatalf("expected 3 personas, got %d", len(list)) + } + + // Should be sorted + if list[0].Name != "alpha" || list[1].Name != "beta" || list[2].Name != "gamma" { + t.Errorf("expected sorted order, got %s, %s, %s", list[0].Name, list[1].Name, list[2].Name) + } +} + +func TestList_Empty(t *testing.T) { + dir := t.TempDir() + r := NewPersonaRegistry(dir) + + list := r.List() + if len(list) != 0 { + t.Errorf("expected empty list, got %d", len(list)) + } +} + +func TestBuiltinPersonas_AreValid(t *testing.T) { + builtins := BuiltinPersonas() + + expectedNames := map[string]bool{ + "default": false, + "reviewer": false, + "architect": false, + "debugger": false, + "teacher": false, + "speed": false, + "planner": false, + "executor": false, + "critic": false, + "security-reviewer": false, + "test-engineer": false, + "tracer": false, + "verifier": false, + "validator": false, + "integrator": false, + "documenter": false, + "devops": false, + "performance": false, + "refactorer": false, + "cavecrew-investigator": false, + "cavecrew-builder": false, + "cavecrew-reviewer": false, + } + + for _, p := range builtins { + if p.Name == "" { + t.Error("built-in persona has empty name") + } + if p.Description == "" { + t.Errorf("built-in persona %q has empty description", p.Name) + } + if p.SystemPrompt == "" { + t.Errorf("built-in persona %q has empty system prompt", p.Name) + } + if len(p.Expertise) == 0 { + t.Errorf("built-in persona %q has no expertise", p.Name) + } + if p.CommunicationStyle == "" { + t.Errorf("built-in persona %q has no communication style", p.Name) + } + if p.CreatedAt.IsZero() { + t.Errorf("built-in persona %q has zero CreatedAt", p.Name) + } + if _, ok := expectedNames[p.Name]; ok { + expectedNames[p.Name] = true + } else { + t.Errorf("unexpected built-in persona: %q", p.Name) + } + } + + for name, found := range expectedNames { + if !found { + t.Errorf("expected built-in persona %q not found", name) + } + } +} + +func TestSelectPersona_NewDomains(t *testing.T) { + r := NewPersonaRegistry(t.TempDir()) + for _, p := range BuiltinPersonas() { + r.Personas[p.Name] = p + } + + cases := []struct { + task string + wantDomain string // expertise the selected persona should include + }{ + {"profile and optimize this slow benchmark with high latency", "performance"}, + {"refactor this module to reduce technical debt and simplify", "refactoring"}, + {"write the readme and api docs with a tutorial guide", "documentation"}, + {"add observability: trace spans and structured logging", "tracing"}, + } + + for _, c := range cases { + p := r.SelectPersona(c.task) + if p == nil { + t.Errorf("task %q selected nil persona", c.task) + continue + } + found := false + for _, e := range p.Expertise { + if e == c.wantDomain { + found = true + } + } + if !found { + t.Errorf("task %q selected %q (expertise %v), expected domain %q", + c.task, p.Name, p.Expertise, c.wantDomain) + } + } +} + +func TestBuiltinPersonas_Count(t *testing.T) { + if got := len(BuiltinPersonas()); got != 22 { + t.Errorf("expected 22 built-in personas, got %d", got) + } +} + +func TestCavecrewPersonas_ReturnsThree(t *testing.T) { + crew := CavecrewPersonas() + if len(crew) != 3 { + t.Fatalf("expected 3 cavecrew personas, got %d", len(crew)) + } + want := []string{"cavecrew-investigator", "cavecrew-builder", "cavecrew-reviewer"} + for i, p := range crew { + if p.Name != want[i] { + t.Errorf("expected %d-th persona %q, got %q", i, want[i], p.Name) + } + if p.Description == "" { + t.Errorf("cavecrew persona %q has empty description", p.Name) + } + if p.SystemPrompt == "" { + t.Errorf("cavecrew persona %q has empty system prompt", p.Name) + } + if len(p.Rules) == 0 { + t.Errorf("cavecrew persona %q has no rules", p.Name) + } + } +} + +func TestCavecrewPersonas_AreInBuiltinList(t *testing.T) { + // Cavecrew personas must be a subset of BuiltinPersonas so + // EnsureBuiltins auto-creates them on first run. + builtins := map[string]bool{} + for _, p := range BuiltinPersonas() { + builtins[p.Name] = true + } + for _, p := range CavecrewPersonas() { + if !builtins[p.Name] { + t.Errorf("cavecrew persona %q missing from BuiltinPersonas", p.Name) + } + } +} + +func TestEnsureCavecrew_WritesFiles(t *testing.T) { + dir := t.TempDir() + r := NewPersonaRegistry(dir) + if err := r.EnsureCavecrew(); err != nil { + t.Fatalf("EnsureCavecrew: %v", err) + } + for _, want := range []string{"cavecrew-investigator.md", "cavecrew-builder.md", "cavecrew-reviewer.md"} { + path := filepath.Join(dir, want) + if _, err := os.Stat(path); err != nil { + t.Errorf("expected file %s: %v", want, err) + } + } +} + +func TestLoadAll_FromDirectory(t *testing.T) { + dir := t.TempDir() + + // Write some persona files + file1 := `--- +name: persona-one +description: First persona +expertise: [backend] +style: concise +temperature: 0.3 +--- +You are persona one. +` + file2 := `--- +name: persona-two +description: Second persona +expertise: [frontend] +style: detailed +temperature: 0.7 +--- +You are persona two. +` + if err := os.WriteFile(filepath.Join(dir, "persona-one.md"), []byte(file1), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "persona-two.md"), []byte(file2), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "not-a-persona.txt"), []byte("ignored"), 0o644); err != nil { + t.Fatal(err) + } + + r := NewPersonaRegistry(dir) + if err := r.LoadAll(); err != nil { + t.Fatalf("LoadAll failed: %v", err) + } + + if len(r.Personas) != 2 { + t.Fatalf("expected 2 personas, got %d", len(r.Personas)) + } + + p1, err := r.Get("persona-one") + if err != nil { + t.Fatalf("Get persona-one failed: %v", err) + } + if p1.Description != "First persona" { + t.Errorf("unexpected description: %q", p1.Description) + } + if p1.Temperature != 0.3 { + t.Errorf("unexpected temperature: %f", p1.Temperature) + } + + p2, err := r.Get("persona-two") + if err != nil { + t.Fatalf("Get persona-two failed: %v", err) + } + if p2.CommunicationStyle != "detailed" { + t.Errorf("unexpected style: %q", p2.CommunicationStyle) + } +} + +func TestLoadAll_NonexistentDirectory(t *testing.T) { + r := NewPersonaRegistry("/tmp/nonexistent-persona-dir-xyz123") + err := r.LoadAll() + if err != nil { + t.Errorf("LoadAll should not error on nonexistent dir, got: %v", err) + } + if len(r.Personas) != 0 { + t.Error("should have no personas loaded") + } +} + +func TestParsePersonaFile_MissingFile(t *testing.T) { + _, err := ParsePersonaFile("/tmp/nonexistent-persona-file-xyz.md") + if err == nil { + t.Error("expected error for missing file") + } +} + +func TestParsePersonaFile_InvalidYAML(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "invalid.md") + + // No frontmatter at all + os.WriteFile(path, []byte("just plain text without frontmatter"), 0o644) + _, err := ParsePersonaFile(path) + if err == nil { + t.Error("expected error for content without frontmatter") + } +} + +func TestParsePersonaFile_MissingClosingFrontmatter(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "unclosed.md") + + content := "---\nname: broken\ndescription: no closing\n" + os.WriteFile(path, []byte(content), 0o644) + _, err := ParsePersonaFile(path) + if err == nil { + t.Error("expected error for missing closing frontmatter") + } +} + +func TestParsePersonaFile_NameFromFilename(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "my-custom-agent.md") + + content := "---\ndescription: has no name field\n---\nBody text" + os.WriteFile(path, []byte(content), 0o644) + + p, err := ParsePersonaFile(path) + if err != nil { + t.Fatal(err) + } + if p.Name != "my-custom-agent" { + t.Errorf("expected name from filename, got %q", p.Name) + } +} + +func TestNewPersonaRegistry_DefaultDir(t *testing.T) { + r := NewPersonaRegistry("") + if r.Dir == "" { + t.Error("default dir should not be empty") + } + if !strings.Contains(r.Dir, ".hawk") { + t.Errorf("default dir should contain .hawk, got %q", r.Dir) + } +} + +func TestEnsureBuiltins(t *testing.T) { + dir := t.TempDir() + r := NewPersonaRegistry(dir) + + if err := r.EnsureBuiltins(); err != nil { + t.Fatalf("EnsureBuiltins failed: %v", err) + } + + // Check files were created + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + + expectedFiles := map[string]bool{ + "default.md": false, + "reviewer.md": false, + "architect.md": false, + "debugger.md": false, + "teacher.md": false, + "speed.md": false, + } + + for _, e := range entries { + if _, ok := expectedFiles[e.Name()]; ok { + expectedFiles[e.Name()] = true + } + } + + for name, found := range expectedFiles { + if !found { + t.Errorf("expected built-in file %q not found", name) + } + } + + // Calling again should not overwrite existing files + // Modify a file and verify it is not overwritten + customContent := "---\nname: default\ndescription: custom\n---\nCustom prompt." + os.WriteFile(filepath.Join(dir, "default.md"), []byte(customContent), 0o644) + + if err := r.EnsureBuiltins(); err != nil { + t.Fatal(err) + } + + data, _ := os.ReadFile(filepath.Join(dir, "default.md")) + if !strings.Contains(string(data), "Custom prompt.") { + t.Error("EnsureBuiltins should not overwrite existing files") + } +} + +func TestBuildSystemPrompt_AllStyles(t *testing.T) { + styles := map[string]string{ + "concise": "brief and to the point", + "detailed": "thorough explanations", + "tutorial": "step by step", + "pair-programming": "Collaborate interactively", + } + + for style, expected := range styles { + p := &Persona{ + Name: "test", + SystemPrompt: "Base prompt.", + CommunicationStyle: style, + } + result := BuildSystemPrompt(p, "") + if !strings.Contains(result, expected) { + t.Errorf("style %q: expected to contain %q, got: %s", style, expected, result) + } + } +} + +func TestPersonaRegistry_ConcurrentAccess(t *testing.T) { + dir := t.TempDir() + r := NewPersonaRegistry(dir) + + // Pre-populate + for i := 0; i < 10; i++ { + name := fmt.Sprintf("persona-%d", i) + r.Personas[name] = &Persona{ + Name: name, + Expertise: []string{"backend"}, + } + } + + // Concurrent reads + done := make(chan bool, 20) + for i := 0; i < 10; i++ { + go func() { + _ = r.List() + done <- true + }() + go func(idx int) { + name := fmt.Sprintf("persona-%d", idx) + _, _ = r.Get(name) + done <- true + }(i) + } + + for i := 0; i < 20; i++ { + <-done + } +} + +func TestSelectPersona_MultipleKeywordMatch(t *testing.T) { + r := NewPersonaRegistry(t.TempDir()) + r.Personas["security"] = &Persona{ + Name: "security", + Expertise: []string{"security"}, + } + r.Personas["full-stack"] = &Persona{ + Name: "full-stack", + Expertise: []string{"security", "backend"}, + } + + // Task that matches both security and backend keywords + p := r.SelectPersona("fix SQL injection vulnerability in the API endpoint") + if p == nil { + t.Fatal("expected a persona match") + } + // full-stack should win because it matches both security + backend keywords + if p.Name != "full-stack" { + t.Errorf("expected full-stack (more keyword matches), got %q", p.Name) + } +} diff --git a/internal/multiagent/agents/persona_test.go b/internal/multiagent/agents/persona_test.go index 3f6a3597..a560ce19 100644 --- a/internal/multiagent/agents/persona_test.go +++ b/internal/multiagent/agents/persona_test.go @@ -1,7 +1,6 @@ package agents import ( - "fmt" "os" "path/filepath" "strings" @@ -200,665 +199,6 @@ func TestRenderPersonaFile_RoundTrip(t *testing.T) { } } -func TestSelectPersona_Security(t *testing.T) { - r := NewPersonaRegistry(t.TempDir()) - r.Personas["security"] = &Persona{ - Name: "security", - Expertise: []string{"security"}, - } - r.Personas["tester"] = &Persona{ - Name: "tester", - Expertise: []string{"testing"}, - } - r.Personas["devops"] = &Persona{ - Name: "devops", - Expertise: []string{"devops"}, - } - - // Security task - p := r.SelectPersona("fix security vulnerability in auth handler") - if p == nil || p.Name != "security" { - t.Errorf("expected security persona, got %v", p) - } -} - -func TestSelectPersona_Testing(t *testing.T) { - r := NewPersonaRegistry(t.TempDir()) - r.Personas["security"] = &Persona{ - Name: "security", - Expertise: []string{"security"}, - } - r.Personas["tester"] = &Persona{ - Name: "tester", - Expertise: []string{"testing"}, - } - - // Testing task - p := r.SelectPersona("write unit tests for the parser") - if p == nil || p.Name != "tester" { - t.Errorf("expected tester persona, got %v", p) - } -} - -func TestSelectPersona_DevOps(t *testing.T) { - r := NewPersonaRegistry(t.TempDir()) - r.Personas["devops"] = &Persona{ - Name: "devops", - Expertise: []string{"devops"}, - } - r.Personas["backend"] = &Persona{ - Name: "backend", - Expertise: []string{"backend"}, - } - - // DevOps task - p := r.SelectPersona("deploy to kubernetes cluster") - if p == nil || p.Name != "devops" { - t.Errorf("expected devops persona, got %v", p) - } -} - -func TestSelectPersona_FallsBackToDefault(t *testing.T) { - r := NewPersonaRegistry(t.TempDir()) - r.Personas["default"] = &Persona{ - Name: "default", - Expertise: []string{}, - } - r.Personas["security"] = &Persona{ - Name: "security", - Expertise: []string{"security"}, - } - - // No keyword match - p := r.SelectPersona("do something random and unrelated") - if p == nil || p.Name != "default" { - t.Errorf("expected default persona as fallback, got %v", p) - } -} - -func TestSelectPersona_NoMatch(t *testing.T) { - r := NewPersonaRegistry(t.TempDir()) - r.Personas["security"] = &Persona{ - Name: "security", - Expertise: []string{"security"}, - } - - // No match and no default - p := r.SelectPersona("play some music") - if p != nil { - t.Errorf("expected nil when no match and no default, got %v", p) - } -} - -func TestBuildSystemPrompt_IncludesAllComponents(t *testing.T) { - p := &Persona{ - Name: "test", - SystemPrompt: "You are a test assistant.", - Expertise: []string{"backend", "testing"}, - CommunicationStyle: "concise", - Rules: []string{"Rule one", "Rule two"}, - Examples: []PersonaExample{ - { - Input: "example input", - Output: "example output", - Context: "example context", - }, - }, - } - - result := BuildSystemPrompt(p, "This is a Go project using REST APIs.") - - // Should contain system prompt - if !strings.Contains(result, "You are a test assistant.") { - t.Error("should contain system prompt") - } - - // Should contain expertise - if !strings.Contains(result, "backend, testing") { - t.Error("should contain expertise") - } - - // Should contain communication style - if !strings.Contains(result, "brief and to the point") { - t.Error("should contain communication style for 'concise'") - } - - // Should contain rules - if !strings.Contains(result, "- Rule one") { - t.Error("should contain rules") - } - if !strings.Contains(result, "- Rule two") { - t.Error("should contain rule two") - } - - // Should contain examples - if !strings.Contains(result, "example input") { - t.Error("should contain example input") - } - if !strings.Contains(result, "example output") { - t.Error("should contain example output") - } - if !strings.Contains(result, "example context") { - t.Error("should contain example context") - } - - // Should contain project context - if !strings.Contains(result, "This is a Go project using REST APIs.") { - t.Error("should contain project context") - } -} - -func TestBuildSystemPrompt_EmptyPersona(t *testing.T) { - p := &Persona{Name: "empty"} - result := BuildSystemPrompt(p, "") - if result != "" { - t.Errorf("expected empty prompt for empty persona, got %q", result) - } -} - -func TestCreateGetDelete_Lifecycle(t *testing.T) { - dir := t.TempDir() - r := NewPersonaRegistry(dir) - - // Create - p := &Persona{ - Name: "lifecycle-test", - Description: "Test persona for lifecycle", - Model: "claude-sonnet-4-6", - Expertise: []string{"testing"}, - SystemPrompt: "You are a lifecycle test.", - } - if err := r.Create(p); err != nil { - t.Fatalf("Create failed: %v", err) - } - - // Verify file exists - path := filepath.Join(dir, "lifecycle-test.md") - if _, err := os.Stat(path); err != nil { - t.Fatalf("persona file not created: %v", err) - } - - // Get - got, err := r.Get("lifecycle-test") - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if got.Name != "lifecycle-test" { - t.Errorf("expected name 'lifecycle-test', got %q", got.Name) - } - if got.Model != "claude-sonnet-4-6" { - t.Errorf("expected model 'claude-sonnet-4-6', got %q", got.Model) - } - - // Delete - if err := r.Delete("lifecycle-test"); err != nil { - t.Fatalf("Delete failed: %v", err) - } - - // Verify gone - if _, err := r.Get("lifecycle-test"); err == nil { - t.Error("expected error after delete, got nil") - } - if _, err := os.Stat(path); !os.IsNotExist(err) { - t.Error("persona file should be deleted from disk") - } -} - -func TestCreate_EmptyName(t *testing.T) { - dir := t.TempDir() - r := NewPersonaRegistry(dir) - - err := r.Create(&Persona{}) - if err == nil { - t.Error("expected error for empty name") - } -} - -func TestDelete_NotFound(t *testing.T) { - dir := t.TempDir() - r := NewPersonaRegistry(dir) - - err := r.Delete("nonexistent") - if err == nil { - t.Error("expected error deleting nonexistent persona") - } -} - -func TestList_ReturnsAllPersonas(t *testing.T) { - dir := t.TempDir() - r := NewPersonaRegistry(dir) - - r.Personas["alpha"] = &Persona{Name: "alpha"} - r.Personas["beta"] = &Persona{Name: "beta"} - r.Personas["gamma"] = &Persona{Name: "gamma"} - - list := r.List() - if len(list) != 3 { - t.Fatalf("expected 3 personas, got %d", len(list)) - } - - // Should be sorted - if list[0].Name != "alpha" || list[1].Name != "beta" || list[2].Name != "gamma" { - t.Errorf("expected sorted order, got %s, %s, %s", list[0].Name, list[1].Name, list[2].Name) - } -} - -func TestList_Empty(t *testing.T) { - dir := t.TempDir() - r := NewPersonaRegistry(dir) - - list := r.List() - if len(list) != 0 { - t.Errorf("expected empty list, got %d", len(list)) - } -} - -func TestBuiltinPersonas_AreValid(t *testing.T) { - builtins := BuiltinPersonas() - - expectedNames := map[string]bool{ - "default": false, - "reviewer": false, - "architect": false, - "debugger": false, - "teacher": false, - "speed": false, - "planner": false, - "executor": false, - "critic": false, - "security-reviewer": false, - "test-engineer": false, - "tracer": false, - "verifier": false, - "validator": false, - "integrator": false, - "documenter": false, - "devops": false, - "performance": false, - "refactorer": false, - "cavecrew-investigator": false, - "cavecrew-builder": false, - "cavecrew-reviewer": false, - } - - for _, p := range builtins { - if p.Name == "" { - t.Error("built-in persona has empty name") - } - if p.Description == "" { - t.Errorf("built-in persona %q has empty description", p.Name) - } - if p.SystemPrompt == "" { - t.Errorf("built-in persona %q has empty system prompt", p.Name) - } - if len(p.Expertise) == 0 { - t.Errorf("built-in persona %q has no expertise", p.Name) - } - if p.CommunicationStyle == "" { - t.Errorf("built-in persona %q has no communication style", p.Name) - } - if p.CreatedAt.IsZero() { - t.Errorf("built-in persona %q has zero CreatedAt", p.Name) - } - if _, ok := expectedNames[p.Name]; ok { - expectedNames[p.Name] = true - } else { - t.Errorf("unexpected built-in persona: %q", p.Name) - } - } - - for name, found := range expectedNames { - if !found { - t.Errorf("expected built-in persona %q not found", name) - } - } -} - -func TestSelectPersona_NewDomains(t *testing.T) { - r := NewPersonaRegistry(t.TempDir()) - for _, p := range BuiltinPersonas() { - r.Personas[p.Name] = p - } - - cases := []struct { - task string - wantDomain string // expertise the selected persona should include - }{ - {"profile and optimize this slow benchmark with high latency", "performance"}, - {"refactor this module to reduce technical debt and simplify", "refactoring"}, - {"write the readme and api docs with a tutorial guide", "documentation"}, - {"add observability: trace spans and structured logging", "tracing"}, - } - - for _, c := range cases { - p := r.SelectPersona(c.task) - if p == nil { - t.Errorf("task %q selected nil persona", c.task) - continue - } - found := false - for _, e := range p.Expertise { - if e == c.wantDomain { - found = true - } - } - if !found { - t.Errorf("task %q selected %q (expertise %v), expected domain %q", - c.task, p.Name, p.Expertise, c.wantDomain) - } - } -} - -func TestBuiltinPersonas_Count(t *testing.T) { - if got := len(BuiltinPersonas()); got != 22 { - t.Errorf("expected 22 built-in personas, got %d", got) - } -} - -func TestCavecrewPersonas_ReturnsThree(t *testing.T) { - crew := CavecrewPersonas() - if len(crew) != 3 { - t.Fatalf("expected 3 cavecrew personas, got %d", len(crew)) - } - want := []string{"cavecrew-investigator", "cavecrew-builder", "cavecrew-reviewer"} - for i, p := range crew { - if p.Name != want[i] { - t.Errorf("expected %d-th persona %q, got %q", i, want[i], p.Name) - } - if p.Description == "" { - t.Errorf("cavecrew persona %q has empty description", p.Name) - } - if p.SystemPrompt == "" { - t.Errorf("cavecrew persona %q has empty system prompt", p.Name) - } - if len(p.Rules) == 0 { - t.Errorf("cavecrew persona %q has no rules", p.Name) - } - } -} - -func TestCavecrewPersonas_AreInBuiltinList(t *testing.T) { - // Cavecrew personas must be a subset of BuiltinPersonas so - // EnsureBuiltins auto-creates them on first run. - builtins := map[string]bool{} - for _, p := range BuiltinPersonas() { - builtins[p.Name] = true - } - for _, p := range CavecrewPersonas() { - if !builtins[p.Name] { - t.Errorf("cavecrew persona %q missing from BuiltinPersonas", p.Name) - } - } -} - -func TestEnsureCavecrew_WritesFiles(t *testing.T) { - dir := t.TempDir() - r := NewPersonaRegistry(dir) - if err := r.EnsureCavecrew(); err != nil { - t.Fatalf("EnsureCavecrew: %v", err) - } - for _, want := range []string{"cavecrew-investigator.md", "cavecrew-builder.md", "cavecrew-reviewer.md"} { - path := filepath.Join(dir, want) - if _, err := os.Stat(path); err != nil { - t.Errorf("expected file %s: %v", want, err) - } - } -} - -func TestLoadAll_FromDirectory(t *testing.T) { - dir := t.TempDir() - - // Write some persona files - file1 := `--- -name: persona-one -description: First persona -expertise: [backend] -style: concise -temperature: 0.3 ---- -You are persona one. -` - file2 := `--- -name: persona-two -description: Second persona -expertise: [frontend] -style: detailed -temperature: 0.7 ---- -You are persona two. -` - if err := os.WriteFile(filepath.Join(dir, "persona-one.md"), []byte(file1), 0o644); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(dir, "persona-two.md"), []byte(file2), 0o644); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(dir, "not-a-persona.txt"), []byte("ignored"), 0o644); err != nil { - t.Fatal(err) - } - - r := NewPersonaRegistry(dir) - if err := r.LoadAll(); err != nil { - t.Fatalf("LoadAll failed: %v", err) - } - - if len(r.Personas) != 2 { - t.Fatalf("expected 2 personas, got %d", len(r.Personas)) - } - - p1, err := r.Get("persona-one") - if err != nil { - t.Fatalf("Get persona-one failed: %v", err) - } - if p1.Description != "First persona" { - t.Errorf("unexpected description: %q", p1.Description) - } - if p1.Temperature != 0.3 { - t.Errorf("unexpected temperature: %f", p1.Temperature) - } - - p2, err := r.Get("persona-two") - if err != nil { - t.Fatalf("Get persona-two failed: %v", err) - } - if p2.CommunicationStyle != "detailed" { - t.Errorf("unexpected style: %q", p2.CommunicationStyle) - } -} - -func TestLoadAll_NonexistentDirectory(t *testing.T) { - r := NewPersonaRegistry("/tmp/nonexistent-persona-dir-xyz123") - err := r.LoadAll() - if err != nil { - t.Errorf("LoadAll should not error on nonexistent dir, got: %v", err) - } - if len(r.Personas) != 0 { - t.Error("should have no personas loaded") - } -} - -func TestParsePersonaFile_MissingFile(t *testing.T) { - _, err := ParsePersonaFile("/tmp/nonexistent-persona-file-xyz.md") - if err == nil { - t.Error("expected error for missing file") - } -} - -func TestParsePersonaFile_InvalidYAML(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "invalid.md") - - // No frontmatter at all - os.WriteFile(path, []byte("just plain text without frontmatter"), 0o644) - _, err := ParsePersonaFile(path) - if err == nil { - t.Error("expected error for content without frontmatter") - } -} - -func TestParsePersonaFile_MissingClosingFrontmatter(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "unclosed.md") - - content := "---\nname: broken\ndescription: no closing\n" - os.WriteFile(path, []byte(content), 0o644) - _, err := ParsePersonaFile(path) - if err == nil { - t.Error("expected error for missing closing frontmatter") - } -} - -func TestParsePersonaFile_NameFromFilename(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "my-custom-agent.md") - - content := "---\ndescription: has no name field\n---\nBody text" - os.WriteFile(path, []byte(content), 0o644) - - p, err := ParsePersonaFile(path) - if err != nil { - t.Fatal(err) - } - if p.Name != "my-custom-agent" { - t.Errorf("expected name from filename, got %q", p.Name) - } -} - -func TestNewPersonaRegistry_DefaultDir(t *testing.T) { - r := NewPersonaRegistry("") - if r.Dir == "" { - t.Error("default dir should not be empty") - } - if !strings.Contains(r.Dir, ".hawk") { - t.Errorf("default dir should contain .hawk, got %q", r.Dir) - } -} - -func TestEnsureBuiltins(t *testing.T) { - dir := t.TempDir() - r := NewPersonaRegistry(dir) - - if err := r.EnsureBuiltins(); err != nil { - t.Fatalf("EnsureBuiltins failed: %v", err) - } - - // Check files were created - entries, err := os.ReadDir(dir) - if err != nil { - t.Fatal(err) - } - - expectedFiles := map[string]bool{ - "default.md": false, - "reviewer.md": false, - "architect.md": false, - "debugger.md": false, - "teacher.md": false, - "speed.md": false, - } - - for _, e := range entries { - if _, ok := expectedFiles[e.Name()]; ok { - expectedFiles[e.Name()] = true - } - } - - for name, found := range expectedFiles { - if !found { - t.Errorf("expected built-in file %q not found", name) - } - } - - // Calling again should not overwrite existing files - // Modify a file and verify it is not overwritten - customContent := "---\nname: default\ndescription: custom\n---\nCustom prompt." - os.WriteFile(filepath.Join(dir, "default.md"), []byte(customContent), 0o644) - - if err := r.EnsureBuiltins(); err != nil { - t.Fatal(err) - } - - data, _ := os.ReadFile(filepath.Join(dir, "default.md")) - if !strings.Contains(string(data), "Custom prompt.") { - t.Error("EnsureBuiltins should not overwrite existing files") - } -} - -func TestBuildSystemPrompt_AllStyles(t *testing.T) { - styles := map[string]string{ - "concise": "brief and to the point", - "detailed": "thorough explanations", - "tutorial": "step by step", - "pair-programming": "Collaborate interactively", - } - - for style, expected := range styles { - p := &Persona{ - Name: "test", - SystemPrompt: "Base prompt.", - CommunicationStyle: style, - } - result := BuildSystemPrompt(p, "") - if !strings.Contains(result, expected) { - t.Errorf("style %q: expected to contain %q, got: %s", style, expected, result) - } - } -} - -func TestPersonaRegistry_ConcurrentAccess(t *testing.T) { - dir := t.TempDir() - r := NewPersonaRegistry(dir) - - // Pre-populate - for i := 0; i < 10; i++ { - name := fmt.Sprintf("persona-%d", i) - r.Personas[name] = &Persona{ - Name: name, - Expertise: []string{"backend"}, - } - } - - // Concurrent reads - done := make(chan bool, 20) - for i := 0; i < 10; i++ { - go func() { - _ = r.List() - done <- true - }() - go func(idx int) { - name := fmt.Sprintf("persona-%d", idx) - _, _ = r.Get(name) - done <- true - }(i) - } - - for i := 0; i < 20; i++ { - <-done - } -} - -func TestSelectPersona_MultipleKeywordMatch(t *testing.T) { - r := NewPersonaRegistry(t.TempDir()) - r.Personas["security"] = &Persona{ - Name: "security", - Expertise: []string{"security"}, - } - r.Personas["full-stack"] = &Persona{ - Name: "full-stack", - Expertise: []string{"security", "backend"}, - } - - // Task that matches both security and backend keywords - p := r.SelectPersona("fix SQL injection vulnerability in the API endpoint") - if p == nil { - t.Fatal("expected a persona match") - } - // full-stack should win because it matches both security + backend keywords - if p.Name != "full-stack" { - t.Errorf("expected full-stack (more keyword matches), got %q", p.Name) - } -} - func TestParseYAMLList(t *testing.T) { tests := []struct { input string From 1b72e0ec5447e8c6078941f66329938fdb747461 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 20:41:18 +0530 Subject: [PATCH 19/20] test(tool): split commitlint config/check tests into commitlint_config_test.go --- internal/tool/commitlint_config_test.go | 634 ++++++++++++++++++++++++ internal/tool/commitlint_test.go | 628 ----------------------- 2 files changed, 634 insertions(+), 628 deletions(-) create mode 100644 internal/tool/commitlint_config_test.go diff --git a/internal/tool/commitlint_config_test.go b/internal/tool/commitlint_config_test.go new file mode 100644 index 00000000..deb657f5 --- /dev/null +++ b/internal/tool/commitlint_config_test.go @@ -0,0 +1,634 @@ +package tool + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLoadFromProject_JSONConfig(t *testing.T) { + dir := t.TempDir() + config := `{ + "extends": ["@commitlint/config-conventional"], + "rules": { + "type-enum": [2, "always", ["feat", "fix", "chore"]], + "subject-max-length": [2, "always", 50], + "header-max-length": [1, "always", 80] + } +}` + err := os.WriteFile(filepath.Join(dir, ".commitlintrc.json"), []byte(config), 0o644) + if err != nil { + t.Fatal(err) + } + + linter := NewCommitLinter() + if err := linter.LoadFromProject(dir); err != nil { + t.Fatalf("LoadFromProject() error: %v", err) + } + + // Verify type-enum was updated. + for _, rule := range linter.Rules { + if rule.Name == "type-enum" { + allowed, ok := rule.Value.([]string) + if !ok { + t.Fatalf("type-enum value is not []string: %T", rule.Value) + } + if len(allowed) != 3 { + t.Errorf("expected 3 allowed types, got %d: %v", len(allowed), allowed) + } + break + } + } + + // Verify subject-max-length was updated. + for _, rule := range linter.Rules { + if rule.Name == "subject-max-length" { + v, ok := rule.Value.(int) + if !ok { + t.Fatalf("subject-max-length value is not int: %T", rule.Value) + } + if v != 50 { + t.Errorf("expected subject-max-length = 50, got %d", v) + } + break + } + } + + // Verify header-max-length is now a warning. + for _, rule := range linter.Rules { + if rule.Name == "header-max-length" { + if rule.Level != "warning" { + t.Errorf("expected header-max-length level = warning, got %q", rule.Level) + } + break + } + } +} + +func TestLoadFromProject_JSConfig(t *testing.T) { + dir := t.TempDir() + config := `module.exports = { + extends: ['@commitlint/config-conventional'], + rules: { + 'type-enum': [2, 'always', ['feat', 'fix', 'docs', 'refactor']], + 'subject-max-length': [2, 'always', 60], + }, +}; +` + err := os.WriteFile(filepath.Join(dir, "commitlint.config.js"), []byte(config), 0o644) + if err != nil { + t.Fatal(err) + } + + linter := NewCommitLinter() + if err := linter.LoadFromProject(dir); err != nil { + t.Fatalf("LoadFromProject() error: %v", err) + } + + // Verify type-enum was updated from JS config. + for _, rule := range linter.Rules { + if rule.Name == "type-enum" { + allowed, ok := rule.Value.([]string) + if !ok { + t.Fatalf("type-enum value is not []string: %T", rule.Value) + } + if len(allowed) != 4 { + t.Errorf("expected 4 allowed types, got %d: %v", len(allowed), allowed) + } + break + } + } +} + +func TestLoadFromProject_NoConfig(t *testing.T) { + dir := t.TempDir() + linter := NewCommitLinter() + err := linter.LoadFromProject(dir) + if err == nil { + t.Error("expected error when no config found") + } +} + +func TestLint_DisabledRule(t *testing.T) { + linter := NewCommitLinter() + // Disable type-enum. + for i, rule := range linter.Rules { + if rule.Name == "type-enum" { + linter.Rules[i].Level = "disabled" + break + } + } + + result := linter.Lint("yolo: anything goes") + // Should not have type-enum error since it's disabled. + for _, e := range result.Errors { + if strings.Contains(e, "type-enum") { + t.Errorf("disabled rule should not produce errors, got: %s", e) + } + } +} + +func TestLint_AllValidTypes(t *testing.T) { + linter := NewCommitLinter() + types := []string{"feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"} + + for _, typ := range types { + msg := typ + ": some description" + result := linter.Lint(msg) + for _, e := range result.Errors { + if strings.Contains(e, "type-enum") { + t.Errorf("type %q should be valid, got error: %s", typ, e) + } + } + } +} + +func TestParseCommitMessage_Empty(t *testing.T) { + parsed := ParseCommitMessage("") + if parsed.Type != "" || parsed.Scope != "" || parsed.Subject != "" { + t.Errorf("expected empty ParsedCommit for empty input, got: %+v", parsed) + } +} + +func TestParseCommitMessage_NoColon(t *testing.T) { + parsed := ParseCommitMessage("just a plain message") + if parsed.Subject != "just a plain message" { + t.Errorf("expected subject = %q, got %q", "just a plain message", parsed.Subject) + } +} + +func TestLint_BreakingChange(t *testing.T) { + linter := NewCommitLinter() + result := linter.Lint("feat!: breaking API change") + + if !result.Valid { + t.Errorf("breaking change commit should be valid, errors: %v", result.Errors) + } + + parsed := ParseCommitMessage("feat!: breaking API change") + if !parsed.Breaking { + t.Error("expected Breaking = true for '!' indicator") + } +} + +func TestCommitLinter_ConcurrentAccess(t *testing.T) { + linter := NewCommitLinter() + done := make(chan bool, 10) + + for i := 0; i < 10; i++ { + go func() { + result := linter.Lint("feat: concurrent test") + if !result.Valid { + t.Errorf("concurrent lint failed: %v", result.Errors) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + <-done + } +} + +func TestIsFooter(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"Refs: #123", true}, + {"BREAKING CHANGE: removed old API", true}, + {"Reviewed-by: Alice", true}, + {"just some body text", false}, + {"", false}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := isFooter(tt.input) + if got != tt.want { + t.Errorf("isFooter(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestInferTypeFromContent(t *testing.T) { + tests := []struct { + subject string + body string + want string + }{ + {"fix the login bug", "", "fix"}, + {"resolve null pointer error", "", "fix"}, + {"add new user authentication", "", "feat"}, + {"implement OAuth2 flow", "", "feat"}, + {"update readme", "", "docs"}, + {"add unit tests", "", "test"}, + {"restructure the codebase", "", "refactor"}, + {"format code with gofmt", "", "style"}, + {"optimize query performance", "", "perf"}, + {"bump dependency version", "", "chore"}, + {"", "", "chore"}, + } + for _, tt := range tests { + t.Run(tt.subject, func(t *testing.T) { + got := inferTypeFromContent(tt.subject, tt.body) + if got != tt.want { + t.Errorf("inferTypeFromContent(%q, %q) = %q, want %q", tt.subject, tt.body, got, tt.want) + } + }) + } +} + +func TestSplitRuleValue(t *testing.T) { + tests := []struct { + input string + want int + }{ + {"2, always, [feat,fix]", 3}, + {"2, always", 2}, + {"2, always, 50", 3}, + {"single", 1}, + {"", 0}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := splitRuleValue(tt.input) + if len(got) != tt.want { + t.Errorf("splitRuleValue(%q) returned %d parts, want %d: %v", tt.input, len(got), tt.want, got) + } + }) + } +} + +func TestParseJSValue(t *testing.T) { + tests := []struct { + input string + want interface{} + }{ + {"['feat','fix','docs']", []string{"feat", "fix", "docs"}}, + {"42", 42}, + {"'lowercase'", "lowercase"}, + {`"lowercase"`, "lowercase"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := parseJSValue(tt.input) + switch want := tt.want.(type) { + case []string: + gotSlice, ok := got.([]string) + if !ok { + t.Errorf("parseJSValue(%q) returned %T, want []string", tt.input, got) + return + } + if len(gotSlice) != len(want) { + t.Errorf("parseJSValue(%q) returned %d items, want %d", tt.input, len(gotSlice), len(want)) + } + case int: + gotInt, ok := got.(int) + if !ok { + t.Errorf("parseJSValue(%q) returned %T, want int", tt.input, got) + return + } + if gotInt != want { + t.Errorf("parseJSValue(%q) = %d, want %d", tt.input, gotInt, want) + } + case string: + gotStr, ok := got.(string) + if !ok { + t.Errorf("parseJSValue(%q) returned %T, want string", tt.input, got) + return + } + if gotStr != want { + t.Errorf("parseJSValue(%q) = %q, want %q", tt.input, gotStr, want) + } + } + }) + } +} + +func TestIsLowerCase(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"feat", true}, + {"Feat", false}, + {"FEAT", false}, + {"", true}, + {"fix123", true}, + {"FIX", false}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := isLowerCase(tt.input) + if got != tt.want { + t.Errorf("isLowerCase(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestLoadFromProject_YAMLConfig(t *testing.T) { + dir := t.TempDir() + config := `extends: + - '@commitlint/config-conventional' +rules: + type-enum: [2, always, 'feat|fix|chore'] + subject-max-length: [2, always, 50] +` + err := os.WriteFile(filepath.Join(dir, ".commitlintrc.yml"), []byte(config), 0o644) + if err != nil { + t.Fatal(err) + } + + linter := NewCommitLinter() + if err := linter.LoadFromProject(dir); err != nil { + t.Fatalf("LoadFromProject() error: %v", err) + } + + // Verify subject-max-length was updated + for _, rule := range linter.Rules { + if rule.Name == "subject-max-length" { + v, ok := rule.Value.(int) + if !ok { + t.Fatalf("subject-max-length value is not int: %T", rule.Value) + } + if v != 50 { + t.Errorf("expected subject-max-length = 50, got %d", v) + } + break + } + } +} + +func TestLoadFromProject_YAMLAlternate(t *testing.T) { + dir := t.TempDir() + config := `rules: + header-max-length: [1, always, 80] +` + err := os.WriteFile(filepath.Join(dir, ".commitlintrc.yaml"), []byte(config), 0o644) + if err != nil { + t.Fatal(err) + } + + linter := NewCommitLinter() + if err := linter.LoadFromProject(dir); err != nil { + t.Fatalf("LoadFromProject() error: %v", err) + } + + // Verify header-max-length was updated + for _, rule := range linter.Rules { + if rule.Name == "header-max-length" { + if rule.Level != "warning" { + t.Errorf("expected header-max-length level = warning, got %q", rule.Level) + } + break + } + } +} + +func TestLint_UnknownRule(t *testing.T) { + linter := NewCommitLinter() + linter.Rules = append(linter.Rules, CommitRule{ + Name: "unknown-rule", + Level: "error", + Applicable: "always", + }) + // Unknown rules should be silently ignored + result := linter.Lint("feat: add login") + if !result.Valid { + t.Errorf("unknown rule should not cause validation failure: %v", result.Errors) + } +} + +func TestCheckSubjectMaxLength_Float64Value(t *testing.T) { + linter := NewCommitLinter() + // Override with float64 value (as JSON parsing might produce) + linter.Rules = []CommitRule{ + {Name: "subject-max-length", Level: "error", Applicable: "always", Value: float64(50)}, + } + msg := "feat: " + strings.Repeat("a", 60) + result := linter.Lint(msg) + found := false + for _, e := range result.Errors { + if strings.Contains(e, "subject-max-length") { + found = true + } + } + if !found { + t.Error("expected subject-max-length error with float64 value") + } +} + +func TestCheckBodyMaxLineLength_Float64Value(t *testing.T) { + linter := NewCommitLinter() + linter.Rules = []CommitRule{ + {Name: "body-max-line-length", Level: "warning", Applicable: "always", Value: float64(80)}, + } + longLine := strings.Repeat("x", 100) + msg := "feat: add\n\n" + longLine + result := linter.Lint(msg) + found := false + for _, w := range result.Warnings { + if strings.Contains(w, "body-max-line-length") { + found = true + } + } + if !found { + t.Error("expected body-max-line-length warning with float64 value") + } +} + +func TestCheckHeaderMaxLength_Float64Value(t *testing.T) { + linter := NewCommitLinter() + linter.Rules = []CommitRule{ + {Name: "header-max-length", Level: "error", Applicable: "always", Value: float64(50)}, + } + msg := "feat: " + strings.Repeat("b", 60) + result := linter.Lint(msg) + found := false + for _, e := range result.Errors { + if strings.Contains(e, "header-max-length") { + found = true + } + } + if !found { + t.Error("expected header-max-length error with float64 value") + } +} + +func TestCheckFooterMaxLineLength_Float64Value(t *testing.T) { + linter := NewCommitLinter() + linter.Rules = []CommitRule{ + {Name: "footer-max-line-length", Level: "warning", Applicable: "always", Value: float64(80)}, + } + longFooter := strings.Repeat("z", 100) + msg := "feat: add\n\nBody.\n\nRefs: #123\n" + longFooter + result := linter.Lint(msg) + found := false + for _, w := range result.Warnings { + if strings.Contains(w, "footer-max-line-length") { + found = true + } + } + if !found { + t.Error("expected footer-max-line-length warning with float64 value") + } +} + +func TestCheckTypeEnum_NeverApplicable(t *testing.T) { + linter := NewCommitLinter() + for i, rule := range linter.Rules { + if rule.Name == "type-enum" { + linter.Rules[i].Applicable = "never" + break + } + } + result := linter.Lint("feat: add login") + found := false + for _, e := range result.Errors { + if strings.Contains(e, "type-enum") { + found = true + } + } + if !found { + t.Error("expected type-enum error with 'never' applicable") + } +} + +func TestCheckTypeCase_EmptyCaseType(t *testing.T) { + linter := NewCommitLinter() + for i, rule := range linter.Rules { + if rule.Name == "type-case" { + linter.Rules[i].Value = "" + break + } + } + // Should default to lowercase check + result := linter.Lint("Feat: add something") + if result.Valid { + t.Error("expected invalid result for uppercase type with empty case type") + } +} + +func TestCheckScopeCase_Uppercase(t *testing.T) { + linter := NewCommitLinter() + for i, rule := range linter.Rules { + if rule.Name == "scope-case" { + linter.Rules[i].Value = "uppercase" + break + } + } + result := linter.Lint("feat(AUTH): add login") + found := false + for _, e := range result.Errors { + if strings.Contains(e, "scope-case") { + found = true + } + } + // With uppercase rule, lowercase scope should fail + if found { + t.Error("unexpected scope-case error for uppercase scope with uppercase rule") + } +} + +func TestCheckSubjectMaxLength_EmptyCaseType(t *testing.T) { + linter := NewCommitLinter() + for i, rule := range linter.Rules { + if rule.Name == "subject-max-length" { + linter.Rules[i].Value = nil + break + } + } + // Should use default 72 + result := linter.Lint("feat: " + strings.Repeat("a", 80)) + found := false + for _, e := range result.Errors { + if strings.Contains(e, "subject-max-length") { + found = true + } + } + if !found { + t.Error("expected subject-max-length error with nil value (default 72)") + } +} + +func TestParseCommitMessage_JustType(t *testing.T) { + parsed := ParseCommitMessage("feat:") + if parsed.Type != "feat" { + t.Errorf("Type = %q, want %q", parsed.Type, "feat") + } + if parsed.Subject != "" { + t.Errorf("Subject = %q, want empty", parsed.Subject) + } +} + +func TestParseCommitMessage_BreakingInBody(t *testing.T) { + msg := "feat: add auth\n\nImplement OAuth2.\n\nBREAKING CHANGE: removed legacy" + parsed := ParseCommitMessage(msg) + if !parsed.Breaking { + t.Error("expected Breaking = true for BREAKING CHANGE in body") + } +} + +func TestParseCommitMessage_BreakingDashChange(t *testing.T) { + msg := "feat: add auth\n\nBREAKING-CHANGE: removed legacy" + parsed := ParseCommitMessage(msg) + if !parsed.Breaking { + t.Error("expected Breaking = true for BREAKING-CHANGE") + } +} + +func TestFixMessage_WithScope(t *testing.T) { + linter := NewCommitLinter() + fixed := linter.FixMessage("Feat(Auth): add login") + if !strings.HasPrefix(fixed, "feat(auth):") { + t.Errorf("expected 'feat(auth):' prefix, got: %q", fixed) + } +} + +func TestFixMessage_LongHeader(t *testing.T) { + linter := NewCommitLinter() + long := "feat: " + strings.Repeat("a", 120) + fixed := linter.FixMessage(long) + if len(fixed) > 100 { + t.Errorf("expected header truncated to 100 chars, got %d: %q", len(fixed), fixed) + } +} + +func TestUpdateRule_NewRule(t *testing.T) { + linter := NewCommitLinter() + initialCount := len(linter.Rules) + linter.updateRule(CommitRule{ + Name: "custom-rule", + Level: "error", + Applicable: "always", + Value: "test", + }) + if len(linter.Rules) != initialCount+1 { + t.Errorf("expected %d rules, got %d", initialCount+1, len(linter.Rules)) + } +} + +func TestUpdateRule_OverrideExisting(t *testing.T) { + linter := NewCommitLinter() + linter.updateRule(CommitRule{ + Name: "type-enum", + Level: "warning", + Applicable: "never", + Value: []string{"a", "b"}, + }) + for _, rule := range linter.Rules { + if rule.Name == "type-enum" { + if rule.Level != "warning" { + t.Errorf("expected level 'warning', got %q", rule.Level) + } + if rule.Applicable != "never" { + t.Errorf("expected applicable 'never', got %q", rule.Applicable) + } + return + } + } + t.Error("type-enum rule not found") +} diff --git a/internal/tool/commitlint_test.go b/internal/tool/commitlint_test.go index a2023a2a..1f016083 100644 --- a/internal/tool/commitlint_test.go +++ b/internal/tool/commitlint_test.go @@ -1,8 +1,6 @@ package tool import ( - "os" - "path/filepath" "strings" "testing" @@ -377,629 +375,3 @@ func TestFormatLintResult_WithWarnings(t *testing.T) { t.Errorf("expected warning content, got: %q", output) } } - -func TestLoadFromProject_JSONConfig(t *testing.T) { - dir := t.TempDir() - config := `{ - "extends": ["@commitlint/config-conventional"], - "rules": { - "type-enum": [2, "always", ["feat", "fix", "chore"]], - "subject-max-length": [2, "always", 50], - "header-max-length": [1, "always", 80] - } -}` - err := os.WriteFile(filepath.Join(dir, ".commitlintrc.json"), []byte(config), 0o644) - if err != nil { - t.Fatal(err) - } - - linter := NewCommitLinter() - if err := linter.LoadFromProject(dir); err != nil { - t.Fatalf("LoadFromProject() error: %v", err) - } - - // Verify type-enum was updated. - for _, rule := range linter.Rules { - if rule.Name == "type-enum" { - allowed, ok := rule.Value.([]string) - if !ok { - t.Fatalf("type-enum value is not []string: %T", rule.Value) - } - if len(allowed) != 3 { - t.Errorf("expected 3 allowed types, got %d: %v", len(allowed), allowed) - } - break - } - } - - // Verify subject-max-length was updated. - for _, rule := range linter.Rules { - if rule.Name == "subject-max-length" { - v, ok := rule.Value.(int) - if !ok { - t.Fatalf("subject-max-length value is not int: %T", rule.Value) - } - if v != 50 { - t.Errorf("expected subject-max-length = 50, got %d", v) - } - break - } - } - - // Verify header-max-length is now a warning. - for _, rule := range linter.Rules { - if rule.Name == "header-max-length" { - if rule.Level != "warning" { - t.Errorf("expected header-max-length level = warning, got %q", rule.Level) - } - break - } - } -} - -func TestLoadFromProject_JSConfig(t *testing.T) { - dir := t.TempDir() - config := `module.exports = { - extends: ['@commitlint/config-conventional'], - rules: { - 'type-enum': [2, 'always', ['feat', 'fix', 'docs', 'refactor']], - 'subject-max-length': [2, 'always', 60], - }, -}; -` - err := os.WriteFile(filepath.Join(dir, "commitlint.config.js"), []byte(config), 0o644) - if err != nil { - t.Fatal(err) - } - - linter := NewCommitLinter() - if err := linter.LoadFromProject(dir); err != nil { - t.Fatalf("LoadFromProject() error: %v", err) - } - - // Verify type-enum was updated from JS config. - for _, rule := range linter.Rules { - if rule.Name == "type-enum" { - allowed, ok := rule.Value.([]string) - if !ok { - t.Fatalf("type-enum value is not []string: %T", rule.Value) - } - if len(allowed) != 4 { - t.Errorf("expected 4 allowed types, got %d: %v", len(allowed), allowed) - } - break - } - } -} - -func TestLoadFromProject_NoConfig(t *testing.T) { - dir := t.TempDir() - linter := NewCommitLinter() - err := linter.LoadFromProject(dir) - if err == nil { - t.Error("expected error when no config found") - } -} - -func TestLint_DisabledRule(t *testing.T) { - linter := NewCommitLinter() - // Disable type-enum. - for i, rule := range linter.Rules { - if rule.Name == "type-enum" { - linter.Rules[i].Level = "disabled" - break - } - } - - result := linter.Lint("yolo: anything goes") - // Should not have type-enum error since it's disabled. - for _, e := range result.Errors { - if strings.Contains(e, "type-enum") { - t.Errorf("disabled rule should not produce errors, got: %s", e) - } - } -} - -func TestLint_AllValidTypes(t *testing.T) { - linter := NewCommitLinter() - types := []string{"feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"} - - for _, typ := range types { - msg := typ + ": some description" - result := linter.Lint(msg) - for _, e := range result.Errors { - if strings.Contains(e, "type-enum") { - t.Errorf("type %q should be valid, got error: %s", typ, e) - } - } - } -} - -func TestParseCommitMessage_Empty(t *testing.T) { - parsed := ParseCommitMessage("") - if parsed.Type != "" || parsed.Scope != "" || parsed.Subject != "" { - t.Errorf("expected empty ParsedCommit for empty input, got: %+v", parsed) - } -} - -func TestParseCommitMessage_NoColon(t *testing.T) { - parsed := ParseCommitMessage("just a plain message") - if parsed.Subject != "just a plain message" { - t.Errorf("expected subject = %q, got %q", "just a plain message", parsed.Subject) - } -} - -func TestLint_BreakingChange(t *testing.T) { - linter := NewCommitLinter() - result := linter.Lint("feat!: breaking API change") - - if !result.Valid { - t.Errorf("breaking change commit should be valid, errors: %v", result.Errors) - } - - parsed := ParseCommitMessage("feat!: breaking API change") - if !parsed.Breaking { - t.Error("expected Breaking = true for '!' indicator") - } -} - -func TestCommitLinter_ConcurrentAccess(t *testing.T) { - linter := NewCommitLinter() - done := make(chan bool, 10) - - for i := 0; i < 10; i++ { - go func() { - result := linter.Lint("feat: concurrent test") - if !result.Valid { - t.Errorf("concurrent lint failed: %v", result.Errors) - } - done <- true - }() - } - - for i := 0; i < 10; i++ { - <-done - } -} - -func TestIsFooter(t *testing.T) { - tests := []struct { - input string - want bool - }{ - {"Refs: #123", true}, - {"BREAKING CHANGE: removed old API", true}, - {"Reviewed-by: Alice", true}, - {"just some body text", false}, - {"", false}, - } - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - got := isFooter(tt.input) - if got != tt.want { - t.Errorf("isFooter(%q) = %v, want %v", tt.input, got, tt.want) - } - }) - } -} - -func TestInferTypeFromContent(t *testing.T) { - tests := []struct { - subject string - body string - want string - }{ - {"fix the login bug", "", "fix"}, - {"resolve null pointer error", "", "fix"}, - {"add new user authentication", "", "feat"}, - {"implement OAuth2 flow", "", "feat"}, - {"update readme", "", "docs"}, - {"add unit tests", "", "test"}, - {"restructure the codebase", "", "refactor"}, - {"format code with gofmt", "", "style"}, - {"optimize query performance", "", "perf"}, - {"bump dependency version", "", "chore"}, - {"", "", "chore"}, - } - for _, tt := range tests { - t.Run(tt.subject, func(t *testing.T) { - got := inferTypeFromContent(tt.subject, tt.body) - if got != tt.want { - t.Errorf("inferTypeFromContent(%q, %q) = %q, want %q", tt.subject, tt.body, got, tt.want) - } - }) - } -} - -func TestSplitRuleValue(t *testing.T) { - tests := []struct { - input string - want int - }{ - {"2, always, [feat,fix]", 3}, - {"2, always", 2}, - {"2, always, 50", 3}, - {"single", 1}, - {"", 0}, - } - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - got := splitRuleValue(tt.input) - if len(got) != tt.want { - t.Errorf("splitRuleValue(%q) returned %d parts, want %d: %v", tt.input, len(got), tt.want, got) - } - }) - } -} - -func TestParseJSValue(t *testing.T) { - tests := []struct { - input string - want interface{} - }{ - {"['feat','fix','docs']", []string{"feat", "fix", "docs"}}, - {"42", 42}, - {"'lowercase'", "lowercase"}, - {`"lowercase"`, "lowercase"}, - } - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - got := parseJSValue(tt.input) - switch want := tt.want.(type) { - case []string: - gotSlice, ok := got.([]string) - if !ok { - t.Errorf("parseJSValue(%q) returned %T, want []string", tt.input, got) - return - } - if len(gotSlice) != len(want) { - t.Errorf("parseJSValue(%q) returned %d items, want %d", tt.input, len(gotSlice), len(want)) - } - case int: - gotInt, ok := got.(int) - if !ok { - t.Errorf("parseJSValue(%q) returned %T, want int", tt.input, got) - return - } - if gotInt != want { - t.Errorf("parseJSValue(%q) = %d, want %d", tt.input, gotInt, want) - } - case string: - gotStr, ok := got.(string) - if !ok { - t.Errorf("parseJSValue(%q) returned %T, want string", tt.input, got) - return - } - if gotStr != want { - t.Errorf("parseJSValue(%q) = %q, want %q", tt.input, gotStr, want) - } - } - }) - } -} - -func TestIsLowerCase(t *testing.T) { - tests := []struct { - input string - want bool - }{ - {"feat", true}, - {"Feat", false}, - {"FEAT", false}, - {"", true}, - {"fix123", true}, - {"FIX", false}, - } - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - got := isLowerCase(tt.input) - if got != tt.want { - t.Errorf("isLowerCase(%q) = %v, want %v", tt.input, got, tt.want) - } - }) - } -} - -func TestLoadFromProject_YAMLConfig(t *testing.T) { - dir := t.TempDir() - config := `extends: - - '@commitlint/config-conventional' -rules: - type-enum: [2, always, 'feat|fix|chore'] - subject-max-length: [2, always, 50] -` - err := os.WriteFile(filepath.Join(dir, ".commitlintrc.yml"), []byte(config), 0o644) - if err != nil { - t.Fatal(err) - } - - linter := NewCommitLinter() - if err := linter.LoadFromProject(dir); err != nil { - t.Fatalf("LoadFromProject() error: %v", err) - } - - // Verify subject-max-length was updated - for _, rule := range linter.Rules { - if rule.Name == "subject-max-length" { - v, ok := rule.Value.(int) - if !ok { - t.Fatalf("subject-max-length value is not int: %T", rule.Value) - } - if v != 50 { - t.Errorf("expected subject-max-length = 50, got %d", v) - } - break - } - } -} - -func TestLoadFromProject_YAMLAlternate(t *testing.T) { - dir := t.TempDir() - config := `rules: - header-max-length: [1, always, 80] -` - err := os.WriteFile(filepath.Join(dir, ".commitlintrc.yaml"), []byte(config), 0o644) - if err != nil { - t.Fatal(err) - } - - linter := NewCommitLinter() - if err := linter.LoadFromProject(dir); err != nil { - t.Fatalf("LoadFromProject() error: %v", err) - } - - // Verify header-max-length was updated - for _, rule := range linter.Rules { - if rule.Name == "header-max-length" { - if rule.Level != "warning" { - t.Errorf("expected header-max-length level = warning, got %q", rule.Level) - } - break - } - } -} - -func TestLint_UnknownRule(t *testing.T) { - linter := NewCommitLinter() - linter.Rules = append(linter.Rules, CommitRule{ - Name: "unknown-rule", - Level: "error", - Applicable: "always", - }) - // Unknown rules should be silently ignored - result := linter.Lint("feat: add login") - if !result.Valid { - t.Errorf("unknown rule should not cause validation failure: %v", result.Errors) - } -} - -func TestCheckSubjectMaxLength_Float64Value(t *testing.T) { - linter := NewCommitLinter() - // Override with float64 value (as JSON parsing might produce) - linter.Rules = []CommitRule{ - {Name: "subject-max-length", Level: "error", Applicable: "always", Value: float64(50)}, - } - msg := "feat: " + strings.Repeat("a", 60) - result := linter.Lint(msg) - found := false - for _, e := range result.Errors { - if strings.Contains(e, "subject-max-length") { - found = true - } - } - if !found { - t.Error("expected subject-max-length error with float64 value") - } -} - -func TestCheckBodyMaxLineLength_Float64Value(t *testing.T) { - linter := NewCommitLinter() - linter.Rules = []CommitRule{ - {Name: "body-max-line-length", Level: "warning", Applicable: "always", Value: float64(80)}, - } - longLine := strings.Repeat("x", 100) - msg := "feat: add\n\n" + longLine - result := linter.Lint(msg) - found := false - for _, w := range result.Warnings { - if strings.Contains(w, "body-max-line-length") { - found = true - } - } - if !found { - t.Error("expected body-max-line-length warning with float64 value") - } -} - -func TestCheckHeaderMaxLength_Float64Value(t *testing.T) { - linter := NewCommitLinter() - linter.Rules = []CommitRule{ - {Name: "header-max-length", Level: "error", Applicable: "always", Value: float64(50)}, - } - msg := "feat: " + strings.Repeat("b", 60) - result := linter.Lint(msg) - found := false - for _, e := range result.Errors { - if strings.Contains(e, "header-max-length") { - found = true - } - } - if !found { - t.Error("expected header-max-length error with float64 value") - } -} - -func TestCheckFooterMaxLineLength_Float64Value(t *testing.T) { - linter := NewCommitLinter() - linter.Rules = []CommitRule{ - {Name: "footer-max-line-length", Level: "warning", Applicable: "always", Value: float64(80)}, - } - longFooter := strings.Repeat("z", 100) - msg := "feat: add\n\nBody.\n\nRefs: #123\n" + longFooter - result := linter.Lint(msg) - found := false - for _, w := range result.Warnings { - if strings.Contains(w, "footer-max-line-length") { - found = true - } - } - if !found { - t.Error("expected footer-max-line-length warning with float64 value") - } -} - -func TestCheckTypeEnum_NeverApplicable(t *testing.T) { - linter := NewCommitLinter() - for i, rule := range linter.Rules { - if rule.Name == "type-enum" { - linter.Rules[i].Applicable = "never" - break - } - } - result := linter.Lint("feat: add login") - found := false - for _, e := range result.Errors { - if strings.Contains(e, "type-enum") { - found = true - } - } - if !found { - t.Error("expected type-enum error with 'never' applicable") - } -} - -func TestCheckTypeCase_EmptyCaseType(t *testing.T) { - linter := NewCommitLinter() - for i, rule := range linter.Rules { - if rule.Name == "type-case" { - linter.Rules[i].Value = "" - break - } - } - // Should default to lowercase check - result := linter.Lint("Feat: add something") - if result.Valid { - t.Error("expected invalid result for uppercase type with empty case type") - } -} - -func TestCheckScopeCase_Uppercase(t *testing.T) { - linter := NewCommitLinter() - for i, rule := range linter.Rules { - if rule.Name == "scope-case" { - linter.Rules[i].Value = "uppercase" - break - } - } - result := linter.Lint("feat(AUTH): add login") - found := false - for _, e := range result.Errors { - if strings.Contains(e, "scope-case") { - found = true - } - } - // With uppercase rule, lowercase scope should fail - if found { - t.Error("unexpected scope-case error for uppercase scope with uppercase rule") - } -} - -func TestCheckSubjectMaxLength_EmptyCaseType(t *testing.T) { - linter := NewCommitLinter() - for i, rule := range linter.Rules { - if rule.Name == "subject-max-length" { - linter.Rules[i].Value = nil - break - } - } - // Should use default 72 - result := linter.Lint("feat: " + strings.Repeat("a", 80)) - found := false - for _, e := range result.Errors { - if strings.Contains(e, "subject-max-length") { - found = true - } - } - if !found { - t.Error("expected subject-max-length error with nil value (default 72)") - } -} - -func TestParseCommitMessage_JustType(t *testing.T) { - parsed := ParseCommitMessage("feat:") - if parsed.Type != "feat" { - t.Errorf("Type = %q, want %q", parsed.Type, "feat") - } - if parsed.Subject != "" { - t.Errorf("Subject = %q, want empty", parsed.Subject) - } -} - -func TestParseCommitMessage_BreakingInBody(t *testing.T) { - msg := "feat: add auth\n\nImplement OAuth2.\n\nBREAKING CHANGE: removed legacy" - parsed := ParseCommitMessage(msg) - if !parsed.Breaking { - t.Error("expected Breaking = true for BREAKING CHANGE in body") - } -} - -func TestParseCommitMessage_BreakingDashChange(t *testing.T) { - msg := "feat: add auth\n\nBREAKING-CHANGE: removed legacy" - parsed := ParseCommitMessage(msg) - if !parsed.Breaking { - t.Error("expected Breaking = true for BREAKING-CHANGE") - } -} - -func TestFixMessage_WithScope(t *testing.T) { - linter := NewCommitLinter() - fixed := linter.FixMessage("Feat(Auth): add login") - if !strings.HasPrefix(fixed, "feat(auth):") { - t.Errorf("expected 'feat(auth):' prefix, got: %q", fixed) - } -} - -func TestFixMessage_LongHeader(t *testing.T) { - linter := NewCommitLinter() - long := "feat: " + strings.Repeat("a", 120) - fixed := linter.FixMessage(long) - if len(fixed) > 100 { - t.Errorf("expected header truncated to 100 chars, got %d: %q", len(fixed), fixed) - } -} - -func TestUpdateRule_NewRule(t *testing.T) { - linter := NewCommitLinter() - initialCount := len(linter.Rules) - linter.updateRule(CommitRule{ - Name: "custom-rule", - Level: "error", - Applicable: "always", - Value: "test", - }) - if len(linter.Rules) != initialCount+1 { - t.Errorf("expected %d rules, got %d", initialCount+1, len(linter.Rules)) - } -} - -func TestUpdateRule_OverrideExisting(t *testing.T) { - linter := NewCommitLinter() - linter.updateRule(CommitRule{ - Name: "type-enum", - Level: "warning", - Applicable: "never", - Value: []string{"a", "b"}, - }) - for _, rule := range linter.Rules { - if rule.Name == "type-enum" { - if rule.Level != "warning" { - t.Errorf("expected level 'warning', got %q", rule.Level) - } - if rule.Applicable != "never" { - t.Errorf("expected applicable 'never', got %q", rule.Applicable) - } - return - } - } - t.Error("type-enum rule not found") -} From e3e68b2dc66fc3f11a9607b6516489c11c1bc748 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 20:43:37 +0530 Subject: [PATCH 20/20] test(tool): split RefactorTool.Execute tests into refactor_tool_test.go --- internal/tool/refactor_test.go | 386 --------------------------- internal/tool/refactor_tool_test.go | 394 ++++++++++++++++++++++++++++ 2 files changed, 394 insertions(+), 386 deletions(-) create mode 100644 internal/tool/refactor_tool_test.go diff --git a/internal/tool/refactor_test.go b/internal/tool/refactor_test.go index 771b7564..ca77052b 100644 --- a/internal/tool/refactor_test.go +++ b/internal/tool/refactor_test.go @@ -1,8 +1,6 @@ package tool import ( - "context" - "encoding/json" "os" "path/filepath" "strings" @@ -615,387 +613,3 @@ func TestFormatRefactoringResult_Nil(t *testing.T) { t.Errorf("expected 'No result', got %q", output) } } - -func TestRefactorTool_Interface(t *testing.T) { - tool := NewRefactorTool() - if tool.Name() != "Refactor" { - t.Errorf("expected name Refactor, got %s", tool.Name()) - } - if tool.Description() == "" { - t.Error("expected non-empty description") - } - params := tool.Parameters() - if params == nil { - t.Fatal("expected non-nil parameters") - } - props, ok := params["properties"].(map[string]interface{}) - if !ok { - t.Fatal("expected properties map") - } - if _, exists := props["action"]; !exists { - t.Error("expected action property") - } - if _, exists := props["file"]; !exists { - t.Error("expected file property") - } -} - -func TestRefactorTool_Execute_RenameSymbol(t *testing.T) { - dir := t.TempDir() - file := filepath.Join(dir, "main.go") - - content := `package main - -var counter = 0 - -func increment() { - counter++ -} -` - os.WriteFile(file, []byte(content), 0o644) - - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "rename_symbol", - "file": file, - "old_name": "counter", - "new_name": "count", - }) - - output, err := tool.Execute(context.Background(), input) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(output, "rename_symbol") { - t.Error("expected rename_symbol in output") - } - - data, _ := os.ReadFile(file) - got := string(data) - if strings.Contains(got, "counter") { - t.Error("old name should be replaced") - } -} - -func TestRefactorTool_Execute_MissingAction(t *testing.T) { - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "file": "/tmp/test.go", - }) - - _, err := tool.Execute(context.Background(), input) - if err == nil { - t.Fatal("expected error for missing action") - } -} - -func TestRefactorTool_Execute_MissingFile(t *testing.T) { - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "sort_imports", - }) - - _, err := tool.Execute(context.Background(), input) - if err == nil { - t.Fatal("expected error for missing file") - } -} - -func TestRefactorTool_Execute_UnknownAction(t *testing.T) { - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "unknown_action", - "file": "/tmp/test.go", - }) - - _, err := tool.Execute(context.Background(), input) - if err == nil { - t.Fatal("expected error for unknown action") - } -} - -func TestRefactorTool_Execute_SortImports(t *testing.T) { - dir := t.TempDir() - file := filepath.Join(dir, "main.go") - - content := `package main - -import ( - "os" - "fmt" -) - -func main() { - fmt.Println(os.Args) -} -` - os.WriteFile(file, []byte(content), 0o644) - - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "sort_imports", - "file": file, - }) - - output, err := tool.Execute(context.Background(), input) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(output, "sort_imports") { - t.Error("expected sort_imports in output") - } -} - -func TestRefactorTool_Execute_ExtractFunction(t *testing.T) { - dir := t.TempDir() - file := filepath.Join(dir, "main.go") - - content := `package main - -import "fmt" - -func main() { - fmt.Println("a") - fmt.Println("b") -} -` - os.WriteFile(file, []byte(content), 0o644) - - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "extract_function", - "file": file, - "start_line": 6, - "end_line": 7, - "new_name": "printAB", - }) - - output, err := tool.Execute(context.Background(), input) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(output, "extract_function") { - t.Error("expected extract_function in output") - } -} - -func TestRefactorTool_Execute_InlineVariable(t *testing.T) { - dir := t.TempDir() - file := filepath.Join(dir, "main.go") - - content := `package main - -import "fmt" - -func main() { - val := "test" - fmt.Println(val) -} -` - os.WriteFile(file, []byte(content), 0o644) - - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "inline_variable", - "file": file, - "line": 6, - }) - - output, err := tool.Execute(context.Background(), input) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(output, "inline_variable") { - t.Error("expected inline_variable in output") - } -} - -func TestRefactorTool_Execute_ExtractVariable(t *testing.T) { - dir := t.TempDir() - file := filepath.Join(dir, "main.go") - - content := `package main - -import "fmt" - -func main() { - fmt.Println(1 + 1) -} -` - os.WriteFile(file, []byte(content), 0o644) - - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "extract_variable", - "file": file, - "line": 6, - "expr": "1 + 1", - "var_name": "result", - }) - - output, err := tool.Execute(context.Background(), input) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(output, "extract_variable") { - t.Error("expected extract_variable in output") - } -} - -func TestRefactorTool_Execute_AddErrorCheck(t *testing.T) { - dir := t.TempDir() - file := filepath.Join(dir, "main.go") - - content := `package main - -func main() { - err := doThing() -} -` - os.WriteFile(file, []byte(content), 0o644) - - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "add_error_check", - "file": file, - "line": 4, - }) - - output, err := tool.Execute(context.Background(), input) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(output, "add_error_check") { - t.Error("expected add_error_check in output") - } -} - -func TestRefactorTool_Execute_WrapWithContext(t *testing.T) { - dir := t.TempDir() - file := filepath.Join(dir, "main.go") - - content := `package main - -func work() error { - err := step() - return err -} -` - os.WriteFile(file, []byte(content), 0o644) - - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "wrap_with_context", - "file": file, - "line": 5, - "context": "work failed", - }) - - output, err := tool.Execute(context.Background(), input) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(output, "wrap_with_context") { - t.Error("expected wrap_with_context in output") - } -} - -func TestRefactorTool_Execute_RemoveUnusedParams(t *testing.T) { - dir := t.TempDir() - file := filepath.Join(dir, "main.go") - - content := `package main - -func process(used int, unused string) int { - return used * 2 -} -` - os.WriteFile(file, []byte(content), 0o644) - - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "remove_unused_params", - "file": file, - "func_name": "process", - }) - - output, err := tool.Execute(context.Background(), input) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(output, "remove_unused_params") { - t.Error("expected remove_unused_params in output") - } -} - -func TestRefactorTool_Execute_ConvertTableTest(t *testing.T) { - dir := t.TempDir() - file := filepath.Join(dir, "calc_test.go") - - content := `package main - -import "testing" - -func TestMultiply(t *testing.T) { - result := 3 * 4 - if result != 12 { - t.Fatal("expected 12") - } -} -` - os.WriteFile(file, []byte(content), 0o644) - - tool := NewRefactorTool() - input, _ := json.Marshal(map[string]interface{}{ - "action": "convert_table_test", - "file": file, - "test_func": "TestMultiply", - }) - - output, err := tool.Execute(context.Background(), input) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(output, "convert_table_test") { - t.Error("expected convert_table_test in output") - } -} - -func TestParseParamList(t *testing.T) { - tests := []struct { - input string - expected int - }{ - {"a int, b string", 2}, - {"a, b int", 2}, - {"", 0}, - {"ctx context.Context, name string", 2}, - } - - for _, tt := range tests { - params := parseParamList(tt.input) - if len(params) != tt.expected { - t.Errorf("parseParamList(%q) returned %d params, expected %d", tt.input, len(params), tt.expected) - } - } -} - -func TestDetectParameters(t *testing.T) { - before := []string{ - "\tx := 10", - "\ty := 20", - "\tz := 30", - } - extracted := []string{ - "\tfmt.Println(x, y)", - } - - params := detectParameters(before, extracted) - if len(params) != 2 { - t.Fatalf("expected 2 params, got %d: %v", len(params), params) - } - // Should be sorted. - if params[0] != "x" || params[1] != "y" { - t.Errorf("expected [x, y], got %v", params) - } -} diff --git a/internal/tool/refactor_tool_test.go b/internal/tool/refactor_tool_test.go new file mode 100644 index 00000000..328edc9c --- /dev/null +++ b/internal/tool/refactor_tool_test.go @@ -0,0 +1,394 @@ +package tool + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestRefactorTool_Interface(t *testing.T) { + tool := NewRefactorTool() + if tool.Name() != "Refactor" { + t.Errorf("expected name Refactor, got %s", tool.Name()) + } + if tool.Description() == "" { + t.Error("expected non-empty description") + } + params := tool.Parameters() + if params == nil { + t.Fatal("expected non-nil parameters") + } + props, ok := params["properties"].(map[string]interface{}) + if !ok { + t.Fatal("expected properties map") + } + if _, exists := props["action"]; !exists { + t.Error("expected action property") + } + if _, exists := props["file"]; !exists { + t.Error("expected file property") + } +} + +func TestRefactorTool_Execute_RenameSymbol(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "main.go") + + content := `package main + +var counter = 0 + +func increment() { + counter++ +} +` + os.WriteFile(file, []byte(content), 0o644) + + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "rename_symbol", + "file": file, + "old_name": "counter", + "new_name": "count", + }) + + output, err := tool.Execute(context.Background(), input) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(output, "rename_symbol") { + t.Error("expected rename_symbol in output") + } + + data, _ := os.ReadFile(file) + got := string(data) + if strings.Contains(got, "counter") { + t.Error("old name should be replaced") + } +} + +func TestRefactorTool_Execute_MissingAction(t *testing.T) { + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "file": "/tmp/test.go", + }) + + _, err := tool.Execute(context.Background(), input) + if err == nil { + t.Fatal("expected error for missing action") + } +} + +func TestRefactorTool_Execute_MissingFile(t *testing.T) { + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "sort_imports", + }) + + _, err := tool.Execute(context.Background(), input) + if err == nil { + t.Fatal("expected error for missing file") + } +} + +func TestRefactorTool_Execute_UnknownAction(t *testing.T) { + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "unknown_action", + "file": "/tmp/test.go", + }) + + _, err := tool.Execute(context.Background(), input) + if err == nil { + t.Fatal("expected error for unknown action") + } +} + +func TestRefactorTool_Execute_SortImports(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "main.go") + + content := `package main + +import ( + "os" + "fmt" +) + +func main() { + fmt.Println(os.Args) +} +` + os.WriteFile(file, []byte(content), 0o644) + + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "sort_imports", + "file": file, + }) + + output, err := tool.Execute(context.Background(), input) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(output, "sort_imports") { + t.Error("expected sort_imports in output") + } +} + +func TestRefactorTool_Execute_ExtractFunction(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "main.go") + + content := `package main + +import "fmt" + +func main() { + fmt.Println("a") + fmt.Println("b") +} +` + os.WriteFile(file, []byte(content), 0o644) + + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "extract_function", + "file": file, + "start_line": 6, + "end_line": 7, + "new_name": "printAB", + }) + + output, err := tool.Execute(context.Background(), input) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(output, "extract_function") { + t.Error("expected extract_function in output") + } +} + +func TestRefactorTool_Execute_InlineVariable(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "main.go") + + content := `package main + +import "fmt" + +func main() { + val := "test" + fmt.Println(val) +} +` + os.WriteFile(file, []byte(content), 0o644) + + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "inline_variable", + "file": file, + "line": 6, + }) + + output, err := tool.Execute(context.Background(), input) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(output, "inline_variable") { + t.Error("expected inline_variable in output") + } +} + +func TestRefactorTool_Execute_ExtractVariable(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "main.go") + + content := `package main + +import "fmt" + +func main() { + fmt.Println(1 + 1) +} +` + os.WriteFile(file, []byte(content), 0o644) + + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "extract_variable", + "file": file, + "line": 6, + "expr": "1 + 1", + "var_name": "result", + }) + + output, err := tool.Execute(context.Background(), input) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(output, "extract_variable") { + t.Error("expected extract_variable in output") + } +} + +func TestRefactorTool_Execute_AddErrorCheck(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "main.go") + + content := `package main + +func main() { + err := doThing() +} +` + os.WriteFile(file, []byte(content), 0o644) + + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "add_error_check", + "file": file, + "line": 4, + }) + + output, err := tool.Execute(context.Background(), input) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(output, "add_error_check") { + t.Error("expected add_error_check in output") + } +} + +func TestRefactorTool_Execute_WrapWithContext(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "main.go") + + content := `package main + +func work() error { + err := step() + return err +} +` + os.WriteFile(file, []byte(content), 0o644) + + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "wrap_with_context", + "file": file, + "line": 5, + "context": "work failed", + }) + + output, err := tool.Execute(context.Background(), input) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(output, "wrap_with_context") { + t.Error("expected wrap_with_context in output") + } +} + +func TestRefactorTool_Execute_RemoveUnusedParams(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "main.go") + + content := `package main + +func process(used int, unused string) int { + return used * 2 +} +` + os.WriteFile(file, []byte(content), 0o644) + + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "remove_unused_params", + "file": file, + "func_name": "process", + }) + + output, err := tool.Execute(context.Background(), input) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(output, "remove_unused_params") { + t.Error("expected remove_unused_params in output") + } +} + +func TestRefactorTool_Execute_ConvertTableTest(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "calc_test.go") + + content := `package main + +import "testing" + +func TestMultiply(t *testing.T) { + result := 3 * 4 + if result != 12 { + t.Fatal("expected 12") + } +} +` + os.WriteFile(file, []byte(content), 0o644) + + tool := NewRefactorTool() + input, _ := json.Marshal(map[string]interface{}{ + "action": "convert_table_test", + "file": file, + "test_func": "TestMultiply", + }) + + output, err := tool.Execute(context.Background(), input) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(output, "convert_table_test") { + t.Error("expected convert_table_test in output") + } +} + +func TestParseParamList(t *testing.T) { + tests := []struct { + input string + expected int + }{ + {"a int, b string", 2}, + {"a, b int", 2}, + {"", 0}, + {"ctx context.Context, name string", 2}, + } + + for _, tt := range tests { + params := parseParamList(tt.input) + if len(params) != tt.expected { + t.Errorf("parseParamList(%q) returned %d params, expected %d", tt.input, len(params), tt.expected) + } + } +} + +func TestDetectParameters(t *testing.T) { + before := []string{ + "\tx := 10", + "\ty := 20", + "\tz := 30", + } + extracted := []string{ + "\tfmt.Println(x, y)", + } + + params := detectParameters(before, extracted) + if len(params) != 2 { + t.Fatalf("expected 2 params, got %d: %v", len(params), params) + } + // Should be sorted. + if params[0] != "x" || params[1] != "y" { + t.Errorf("expected [x, y], got %v", params) + } +}