package graph
import (
"context"
"github.com/google/uuid"
"github.com/siherrmann/grapher/model"
)
// GraphDB defines the interface for graph operations needed by the
// traversal functions in this package.
type GraphDB interface {
	// GetChunk returns the chunk with the given ID (string form of a UUID).
	GetChunk(ctx context.Context, id string) (*model.Chunk, error)
	// GetEdgesFromChunk returns the edges attached to the given chunk,
	// optionally filtered by edge type. The exact filtering semantics of
	// followBidirectional are implementation-defined — confirm with the
	// concrete store.
	GetEdgesFromChunk(ctx context.Context, chunkID string, edgeTypes []model.EdgeType, followBidirectional bool) ([]*model.Edge, error)
}
// TraversalResult contains a chunk and its distance from the source.
type TraversalResult struct {
	Chunk    *model.Chunk // the visited chunk
	Distance int          // number of hops from the traversal source
	Path     []uuid.UUID  // Path from source to this chunk (inclusive of both ends)
}
// BFS performs breadth-first search from a source chunk
func BFS(ctx context.Context, db GraphDB, sourceID uuid.UUID, maxHops int, edgeTypes []model.EdgeType, followBidirectional bool) ([]*TraversalResult, error) {
visited := make(map[uuid.UUID]bool)
queue := []TraversalResult{{
Chunk: nil,
Distance: 0,
Path: []uuid.UUID{sourceID},
}}
// Get source chunk
sourceChunk, err := db.GetChunk(ctx, sourceID.String())
if err != nil {
return nil, err
}
queue[0].Chunk = sourceChunk
var results []*TraversalResult
visited[sourceID] = true
for len(queue) > 0 {
current := queue[0]
queue = queue[1:]
results = append(results, ¤t)
// Stop if we've reached max hops
if current.Distance >= maxHops {
continue
}
// Get edges from current chunk
edges, err := db.GetEdgesFromChunk(ctx, current.Chunk.ID.String(), edgeTypes, followBidirectional)
if err != nil {
return nil, err
}
// Process each edge
for _, edge := range edges {
var targetID uuid.UUID
// Determine target based on edge direction
if edge.SourceChunkID != nil && *edge.SourceChunkID == current.Chunk.ID && edge.TargetChunkID != nil {
targetID = *edge.TargetChunkID
} else if edge.Bidirectional && edge.TargetChunkID != nil && *edge.TargetChunkID == current.Chunk.ID && edge.SourceChunkID != nil {
targetID = *edge.SourceChunkID
} else {
continue // Skip entity edges or invalid edges
}
// Skip if already visited
if visited[targetID] {
continue
}
// Get target chunk
targetChunk, err := db.GetChunk(ctx, targetID.String())
if err != nil {
continue // Skip if chunk not found
}
visited[targetID] = true
// Create new path
newPath := make([]uuid.UUID, len(current.Path))
copy(newPath, current.Path)
newPath = append(newPath, targetID)
queue = append(queue, TraversalResult{
Chunk: targetChunk,
Distance: current.Distance + 1,
Path: newPath,
})
}
}
return results, nil
}
// DFS performs depth-first search from a source chunk, collecting every
// chunk reachable within maxHops edges along with its distance from the
// source and the UUID path leading to it.
func DFS(ctx context.Context, db GraphDB, sourceID uuid.UUID, maxHops int, edgeTypes []model.EdgeType, followBidirectional bool) ([]*TraversalResult, error) {
	// Resolve the traversal root before recursing.
	root, err := db.GetChunk(ctx, sourceID.String())
	if err != nil {
		return nil, err
	}
	seen := make(map[uuid.UUID]bool)
	var collected []*TraversalResult
	dfsRecursive(ctx, db, root, 0, maxHops, []uuid.UUID{sourceID}, edgeTypes, followBidirectional, seen, &collected)
	return collected, nil
}
// dfsRecursive is the recursive helper for DFS. It records the current
// chunk as a result (pre-order), then — unless maxHops is reached —
// follows outgoing edges, and the reverse direction of bidirectional
// edges, recursing into unvisited neighbors. visited is shared across the
// whole traversal so each chunk appears at most once in results.
func dfsRecursive(
	ctx context.Context,
	db GraphDB,
	current *model.Chunk,
	distance int, // hops from the traversal source
	maxHops int, // recursion stops once distance reaches this bound
	path []uuid.UUID, // chunk IDs from the source to current, inclusive
	edgeTypes []model.EdgeType,
	followBidirectional bool, // forwarded to the DB edge query
	visited map[uuid.UUID]bool,
	results *[]*TraversalResult,
) {
	// Mark as visited
	visited[current.ID] = true
	// Add to results, with a defensive copy of the path since the caller
	// keeps reusing its own slice.
	pathCopy := make([]uuid.UUID, len(path))
	copy(pathCopy, path)
	*results = append(*results, &TraversalResult{
		Chunk:    current,
		Distance: distance,
		Path:     pathCopy,
	})
	// Stop if we've reached max hops
	if distance >= maxHops {
		return
	}
	// Get edges from current chunk; on error the subtree is silently
	// abandoned (traversal is best-effort).
	edges, err := db.GetEdgesFromChunk(ctx, current.ID.String(), edgeTypes, followBidirectional)
	if err != nil {
		return
	}
	// Process each edge
	for _, edge := range edges {
		var targetID uuid.UUID
		// Determine target based on edge direction
		if edge.SourceChunkID != nil && *edge.SourceChunkID == current.ID && edge.TargetChunkID != nil {
			targetID = *edge.TargetChunkID
		} else if edge.Bidirectional && edge.TargetChunkID != nil && *edge.TargetChunkID == current.ID && edge.SourceChunkID != nil {
			targetID = *edge.SourceChunkID
		} else {
			continue // Skip entity edges or invalid edges
		}
		// Skip if already visited
		if visited[targetID] {
			continue
		}
		// Get target chunk
		targetChunk, err := db.GetChunk(ctx, targetID.String())
		if err != nil {
			continue // Skip if chunk not found
		}
		// Create new path (fresh copy so sibling branches don't share
		// backing arrays)
		newPath := make([]uuid.UUID, len(path))
		copy(newPath, path)
		newPath = append(newPath, targetID)
		// Recurse
		dfsRecursive(ctx, db, targetChunk, distance+1, maxHops, newPath, edgeTypes, followBidirectional, visited, results)
	}
}
// GetNeighbors retrieves immediate neighbors (1-hop) of a chunk by running
// a BFS with maxHops=1 and dropping the source chunk itself.
//
// Fix: guard the empty-results case so the capacity expression below can
// never be negative (which would panic).
func GetNeighbors(ctx context.Context, db GraphDB, chunkID uuid.UUID, edgeTypes []model.EdgeType, followBidirectional bool) ([]*model.Chunk, error) {
	results, err := BFS(ctx, db, chunkID, 1, edgeTypes, followBidirectional)
	if err != nil {
		return nil, err
	}
	// The first result is always the source chunk itself.
	if len(results) <= 1 {
		return []*model.Chunk{}, nil
	}
	neighbors := make([]*model.Chunk, 0, len(results)-1)
	for _, r := range results[1:] {
		neighbors = append(neighbors, r.Chunk)
	}
	return neighbors, nil
}
package pipeline
import (
"fmt"
"math"
"strings"
"github.com/knights-analytics/hugot"
"github.com/siherrmann/grapher/helper"
)
// SentenceChunker creates a chunker that splits text into sentences and
// groups up to maxSentencesPerChunk sentences per chunk.
//
// Sentence splitting is heuristic: it breaks on ". ", "! " and "? ", so
// abbreviations like "e.g. " also split. StartPos/EndPos are offsets into
// the sequence of emitted chunk contents, not exact byte offsets into the
// original text.
//
// Fixes: (1) the loop previously iterated the raw split result instead of
// the cleaned sentence list, so empty/whitespace fragments were counted
// toward the chunk size and emitted; (2) ChunkIndex previously pointed at
// the single incrementing loop counter, so every chunk ended up reporting
// the final index — each chunk now gets its own copy.
func SentenceChunker(maxSentencesPerChunk int) ChunkFunc {
	return func(text string, basePath string) ([]ChunkWithPath, error) {
		if maxSentencesPerChunk <= 0 {
			return nil, fmt.Errorf("max sentences per chunk must be positive")
		}
		// Handle empty or whitespace-only text.
		if strings.TrimSpace(text) == "" {
			return []ChunkWithPath{}, nil
		}
		// Mark sentence boundaries with a sentinel character, then split.
		text = strings.ReplaceAll(text, "! ", "!|")
		text = strings.ReplaceAll(text, "? ", "?|")
		text = strings.ReplaceAll(text, ". ", ".|")
		var cleaned []string
		for _, s := range strings.Split(text, "|") {
			s = strings.TrimSpace(s)
			if s != "" {
				cleaned = append(cleaned, s)
			}
		}
		var chunks []ChunkWithPath
		var currentChunk []string
		chunkIdx := 0
		pos := 0
		// emitChunk flushes the accumulated sentences as one chunk.
		emitChunk := func() {
			content := strings.Join(currentChunk, " ")
			startPos := pos
			endPos := pos + len(content)
			idx := chunkIdx // copy so chunks don't alias the counter
			chunks = append(chunks, ChunkWithPath{
				Content:    content,
				Path:       fmt.Sprintf("%s.chunk%d", basePath, idx),
				StartPos:   &startPos,
				EndPos:     &endPos,
				ChunkIndex: &idx,
				Metadata:   make(map[string]interface{}),
			})
			pos = endPos
			currentChunk = nil
			chunkIdx++
		}
		// Iterate the cleaned sentences only.
		for _, sentence := range cleaned {
			currentChunk = append(currentChunk, sentence)
			if len(currentChunk) >= maxSentencesPerChunk {
				emitChunk()
			}
		}
		// Add remaining sentences.
		if len(currentChunk) > 0 {
			emitChunk()
		}
		return chunks, nil
	}
}
// ParagraphChunker creates a chunker that splits by paragraphs (blank-line
// separated). Positions are approximate: they assume a "\n\n" separator
// after every emitted paragraph and do not account for skipped empty
// paragraphs or trimmed whitespace — TODO confirm whether callers rely on
// exact offsets.
//
// Fix: ChunkIndex previously stored &i, aliasing the loop variable; on Go
// versions before 1.22 every chunk would report the same (final) index.
// Each chunk now gets its own copy of the index.
func ParagraphChunker() ChunkFunc {
	return func(text string, basePath string) ([]ChunkWithPath, error) {
		paragraphs := strings.Split(text, "\n\n")
		var chunks []ChunkWithPath
		pos := 0
		for i, para := range paragraphs {
			para = strings.TrimSpace(para)
			if para == "" {
				continue
			}
			startPos := pos
			endPos := pos + len(para)
			// Copy the loop index so each chunk gets a stable value.
			idx := i
			chunks = append(chunks, ChunkWithPath{
				Content:    para,
				Path:       fmt.Sprintf("%s.para%d", basePath, idx),
				StartPos:   &startPos,
				EndPos:     &endPos,
				ChunkIndex: &idx,
				Metadata:   make(map[string]interface{}),
			})
			pos = endPos + 2 // Account for "\n\n"
		}
		return chunks, nil
	}
}
// cosineSimilarity returns the cosine of the angle between two embedding
// vectors. It yields 0 when the vectors differ in length or when either
// one has zero magnitude.
func cosineSimilarity(a, b []float32) float32 {
	if len(a) != len(b) {
		return 0
	}
	var dot, magA, magB float32
	for i, av := range a {
		bv := b[i]
		dot += av * bv
		magA += av * av
		magB += bv * bv
	}
	if magA == 0 || magB == 0 {
		return 0
	}
	return dot / (float32(math.Sqrt(float64(magA))) * float32(math.Sqrt(float64(magB))))
}
// DefaultChunker creates a semantic chunker that uses embeddings to identify
// natural boundaries. It analyzes semantic similarity between sentences and
// starts a new chunk when the next sentence's similarity to the current
// chunk's average embedding drops below similarityThreshold, or when adding
// it would exceed maxChunkSize characters.
//
// Note: model preparation and the hugot session are set up inside the
// returned function, so every invocation pays the full setup cost (unlike
// DefaultEmbedder, which sets up once). Sentence splitting is heuristic
// (". ", "! ", "? ").
//
// Fix: ChunkIndex previously stored &chunkIdx at both emit sites, aliasing
// the single incrementing counter — every chunk would report the final
// index. Each chunk now gets its own copy.
func DefaultChunker(maxChunkSize int, similarityThreshold float32) ChunkFunc {
	return func(text string, basePath string) ([]ChunkWithPath, error) {
		// Prepare model (download if needed)
		modelName := "sentence-transformers/all-MiniLM-L6-v2"
		modelPath, err := helper.PrepareModel(modelName, "onnx/model.onnx")
		if err != nil {
			return nil, err
		}
		// Initialize hugot session with Go backend
		session, err := hugot.NewGoSession()
		if err != nil {
			return nil, fmt.Errorf("failed to create hugot session: %w", err)
		}
		defer session.Destroy()
		// Create sentence transformers pipeline configuration
		config := hugot.FeatureExtractionConfig{
			ModelPath:    modelPath,
			Name:         "semantic-chunker-pipeline",
			OnnxFilename: "onnx/model.onnx",
		}
		sentencePipeline, err := hugot.NewPipeline(session, config)
		if err != nil {
			return nil, fmt.Errorf("failed to create sentence pipeline: %w", err)
		}
		// Split text into sentences (heuristic boundary markers)
		text = strings.ReplaceAll(text, "! ", "!|")
		text = strings.ReplaceAll(text, "? ", "?|")
		text = strings.ReplaceAll(text, ". ", ".|")
		sentences := strings.Split(text, "|")
		var cleanSentences []string
		for _, s := range sentences {
			s = strings.TrimSpace(s)
			if s != "" {
				cleanSentences = append(cleanSentences, s)
			}
		}
		if len(cleanSentences) == 0 {
			return nil, fmt.Errorf("no sentences found in text")
		}
		// Get embeddings for all sentences in one batch
		embeddingResult, err := sentencePipeline.RunPipeline(cleanSentences)
		if err != nil {
			return nil, fmt.Errorf("failed to generate embeddings: %w", err)
		}
		embeddings := embeddingResult.Embeddings
		if len(embeddings) != len(cleanSentences) {
			return nil, fmt.Errorf("embedding count mismatch: got %d embeddings for %d sentences", len(embeddings), len(cleanSentences))
		}
		// Group sentences based on semantic similarity
		var chunks []ChunkWithPath
		var currentChunk []string
		var currentEmbeddings [][]float32
		var currentLength int
		chunkIdx := 0
		pos := 0
		for i, sentence := range cleanSentences {
			sentenceLen := len(sentence)
			shouldBreak := false
			// Check if we should create a chunk boundary
			if len(currentChunk) > 0 {
				// Calculate average embedding of current chunk
				avgEmbedding := make([]float32, len(currentEmbeddings[0]))
				for _, emb := range currentEmbeddings {
					for j := range emb {
						avgEmbedding[j] += emb[j]
					}
				}
				for j := range avgEmbedding {
					avgEmbedding[j] /= float32(len(currentEmbeddings))
				}
				// Calculate similarity between current chunk and new sentence
				similarity := cosineSimilarity(avgEmbedding, embeddings[i])
				// Break if similarity drops below threshold or size limit exceeded
				if similarity < similarityThreshold || currentLength+sentenceLen > maxChunkSize {
					shouldBreak = true
				}
			}
			if shouldBreak {
				// Create chunk
				content := strings.Join(currentChunk, " ")
				startPos := pos
				endPos := pos + len(content)
				// Copy the index so this chunk doesn't alias the counter.
				idx := chunkIdx
				path := fmt.Sprintf("%s.chunk%d", basePath, idx)
				chunks = append(chunks, ChunkWithPath{
					Content:    content,
					Path:       path,
					StartPos:   &startPos,
					EndPos:     &endPos,
					ChunkIndex: &idx,
					Metadata: map[string]interface{}{
						"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
						"num_sentences":   len(currentChunk),
						"chunking_method": "semantic",
					},
				})
				pos = endPos
				currentChunk = nil
				currentEmbeddings = nil
				currentLength = 0
				chunkIdx++
			}
			currentChunk = append(currentChunk, sentence)
			currentEmbeddings = append(currentEmbeddings, embeddings[i])
			currentLength += sentenceLen
			// For the last sentence, create the final chunk
			if i == len(cleanSentences)-1 && len(currentChunk) > 0 {
				content := strings.Join(currentChunk, " ")
				startPos := pos
				endPos := pos + len(content)
				idx := chunkIdx // copy, same reason as above
				path := fmt.Sprintf("%s.chunk%d", basePath, idx)
				chunks = append(chunks, ChunkWithPath{
					Content:    content,
					Path:       path,
					StartPos:   &startPos,
					EndPos:     &endPos,
					ChunkIndex: &idx,
					Metadata: map[string]interface{}{
						"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
						"num_sentences":   len(currentChunk),
						"chunking_method": "semantic",
					},
				})
			}
		}
		return chunks, nil
	}
}
package pipeline
import (
"fmt"
"github.com/knights-analytics/hugot"
"github.com/siherrmann/grapher/helper"
)
// DefaultEmbedder creates an embedder using a real sentence transformer model.
// Uses the all-MiniLM-L6-v2 model which produces 384-dimensional embeddings.
//
// The hugot session is created once here and captured by the returned
// closure; it is never destroyed on the success path, so it lives for the
// lifetime of the process unless the caller arranges cleanup elsewhere —
// NOTE(review): confirm this is intentional.
func DefaultEmbedder() (EmbedFunc, error) {
	// Prepare model (download if needed)
	modelName := "sentence-transformers/all-MiniLM-L6-v2"
	modelPath, err := helper.PrepareModel(modelName, "onnx/model.onnx")
	if err != nil {
		return nil, err
	}
	// Initialize hugot session with Go backend
	session, err := hugot.NewGoSession()
	if err != nil {
		return nil, fmt.Errorf("failed to create hugot session: %w", err)
	}
	// Create sentence transformers pipeline configuration
	config := hugot.FeatureExtractionConfig{
		ModelPath:    modelPath,
		Name:         "embedder-pipeline",
		OnnxFilename: "onnx/model.onnx",
	}
	sentencePipeline, err := hugot.NewPipeline(session, config)
	if err != nil {
		// Tear the session down on failure; if destruction also fails,
		// surface both errors in one message.
		if destroyErr := session.Destroy(); destroyErr != nil {
			return nil, fmt.Errorf("failed to create sentence pipeline: %w (cleanup error: %v)", err, destroyErr)
		}
		return nil, fmt.Errorf("failed to create sentence pipeline: %w", err)
	}
	return func(text string) ([]float32, error) {
		// Generate embedding for the text
		result, err := sentencePipeline.RunPipeline([]string{text})
		if err != nil {
			return nil, fmt.Errorf("failed to generate embedding: %w", err)
		}
		if len(result.Embeddings) == 0 {
			return nil, fmt.Errorf("no embedding generated")
		}
		// Extract the first (and only) embedding — a single input string
		// was passed above.
		embedding := result.Embeddings[0]
		return embedding, nil
	}, nil
}
package pipeline
import (
"fmt"
"strings"
"github.com/google/uuid"
"github.com/knights-analytics/hugot"
"github.com/knights-analytics/hugot/pipelines"
"github.com/siherrmann/grapher/helper"
"github.com/siherrmann/grapher/model"
)
// DefaultEntityExtractor creates an entity extractor using a NER model.
// Uses distilbert-NER for named entity recognition.
// Detects: PERSON, ORGANIZATION, LOCATION, MISC entities.
//
// The hugot session created here is captured by the returned closure and
// is not destroyed on the success path; it lives for the extractor's
// lifetime — NOTE(review): confirm this is intentional.
func DefaultEntityExtractor() (EntityExtractFunc, error) {
	// Prepare model (download if needed)
	// Using KnightsAnalytics optimized distilbert-NER model
	modelName := "KnightsAnalytics/distilbert-NER"
	modelPath, err := helper.PrepareModel(modelName, "model.onnx")
	if err != nil {
		return nil, err
	}
	// Initialize hugot session with Go backend
	session, err := hugot.NewGoSession()
	if err != nil {
		return nil, fmt.Errorf("failed to create hugot session: %w", err)
	}
	// Create token classification pipeline for NER
	config := hugot.TokenClassificationConfig{
		ModelPath: modelPath,
		Name:      "ner-pipeline",
		Options: []hugot.TokenClassificationOption{
			pipelines.WithSimpleAggregation(),
			pipelines.WithIgnoreLabels([]string{"O"}), // Ignore non-entity tokens
		},
	}
	nerPipeline, err := hugot.NewPipeline(session, config)
	if err != nil {
		// Destroy the session on failure; report both errors if the
		// cleanup fails as well.
		if destroyErr := session.Destroy(); destroyErr != nil {
			return nil, fmt.Errorf("failed to create NER pipeline: %w (cleanup error: %v)", err, destroyErr)
		}
		return nil, fmt.Errorf("failed to create NER pipeline: %w", err)
	}
	return func(text string) ([]*model.Entity, error) {
		// Run NER on the text
		result, err := nerPipeline.RunPipeline([]string{text})
		if err != nil {
			return nil, fmt.Errorf("failed to run NER: %w", err)
		}
		// No detections: return nil, nil (callers treat nil as "no entities").
		if len(result.Entities) == 0 {
			return nil, nil
		}
		// Convert NER results to model.Entity. result.Entities[0] holds
		// the detections for the single input string passed above.
		var entities []*model.Entity
		for _, entity := range result.Entities[0] {
			// Normalize entity type (remove B- and I- prefixes)
			entityType := normalizeEntityType(entity.Entity)
			entities = append(entities, &model.Entity{
				ID:   uuid.New(),
				Name: strings.TrimSpace(entity.Word),
				Type: entityType,
				Metadata: map[string]interface{}{
					"confidence": entity.Score,
					"start":      entity.Start,
					"end":        entity.End,
				},
			})
		}
		return entities, nil
	}, nil
}
// normalizeEntityType strips the BIO-tagging prefix ("B-" for beginning,
// "I-" for inside) from a NER label, returning the bare entity type.
// Labels without a recognized prefix are returned unchanged.
func normalizeEntityType(label string) string {
	switch {
	case strings.HasPrefix(label, "B-"), strings.HasPrefix(label, "I-"):
		return label[2:]
	default:
		return label
	}
}
package pipeline
import "github.com/siherrmann/grapher/model"
// ChunkFunc is a function that splits text into chunks with their hierarchical paths.
// The path should follow ltree format (e.g., "doc.chapter1.section2.chunk3").
type ChunkFunc func(text string, basePath string) ([]ChunkWithPath, error)

// EmbedFunc is a function that generates an embedding vector for a piece of text.
type EmbedFunc func(text string) ([]float32, error)

// EntityExtractFunc extracts entities from text.
// Returns a list of entities with their types and metadata.
type EntityExtractFunc func(text string) ([]*model.Entity, error)

// RelationExtractFunc extracts relationships between entities or chunks.
// Returns a list of edges representing the relationships. chunkID identifies
// the chunk the text came from (Pipeline.ProcessWithExtraction passes the
// chunk's ltree path here).
type RelationExtractFunc func(text string, chunkID string, entities []*model.Entity) ([]*model.Edge, error)
// ChunkWithPath represents a chunk of text together with its hierarchical path.
type ChunkWithPath struct {
	Content    string
	Path       string // ltree path (e.g. "doc.section1.chunk0")
	StartPos   *int   // optional start offset; chunkers may compute this approximately
	EndPos     *int   // optional end offset; chunkers may compute this approximately
	ChunkIndex *int   // optional sequential index of the chunk within its document
	Metadata   map[string]interface{}
}
// Pipeline combines chunking and embedding functions, with optional
// entity and relation extraction stages.
type Pipeline struct {
	Chunker           ChunkFunc
	Embedder          EmbedFunc
	EntityExtractor   EntityExtractFunc   // Optional; nil disables entity extraction
	RelationExtractor RelationExtractFunc // Optional; nil disables relation extraction
}
// NewPipeline creates a new processing pipeline from the given chunking
// and embedding functions. Entity and relation extractors can be added
// afterwards via SetEntityExtractor and SetRelationExtractor.
func NewPipeline(chunker ChunkFunc, embedder EmbedFunc) *Pipeline {
	p := new(Pipeline)
	p.Chunker = chunker
	p.Embedder = embedder
	return p
}
// SetEntityExtractor sets the entity extraction function used by
// ProcessWithExtraction. Passing nil disables entity extraction.
func (p *Pipeline) SetEntityExtractor(extractor EntityExtractFunc) {
	p.EntityExtractor = extractor
}
// SetRelationExtractor sets the relation extraction function used by
// ProcessWithExtraction. Passing nil disables relation extraction.
func (p *Pipeline) SetRelationExtractor(extractor RelationExtractFunc) {
	p.RelationExtractor = extractor
}
// ProcessingResult contains chunks and optionally extracted entities and relations.
type ProcessingResult struct {
	Chunks    []*model.Chunk  // embedded chunks, one per chunker output
	Entities  []*model.Entity // nil unless an entity extractor is configured
	Relations []*model.Edge   // nil unless a relation extractor is configured
}
// Process runs the pipeline on text and returns the resulting chunks with
// embeddings. It is a convenience wrapper around ProcessWithExtraction
// that discards any extracted entities and relations.
func (p *Pipeline) Process(text string, basePath string) ([]*model.Chunk, error) {
	res, err := p.ProcessWithExtraction(text, basePath)
	if err != nil {
		return nil, err
	}
	return res.Chunks, nil
}
// ProcessWithExtraction processes text and optionally extracts entities and
// relations. The text is chunked, every chunk is embedded, and — when the
// corresponding extractors are configured — entities and relations are
// collected per chunk. Chunking or embedding errors abort the whole run;
// extractor errors are silently swallowed (best-effort enrichment).
func (p *Pipeline) ProcessWithExtraction(text string, basePath string) (*ProcessingResult, error) {
	// Split into chunks
	chunksWithPath, err := p.Chunker(text, basePath)
	if err != nil {
		return nil, err
	}
	// Generate embeddings
	chunks := make([]*model.Chunk, 0, len(chunksWithPath))
	var allEntities []*model.Entity
	var allRelations []*model.Edge
	for _, cwp := range chunksWithPath {
		embedding, err := p.Embedder(cwp.Content)
		if err != nil {
			return nil, err
		}
		chunk := &model.Chunk{
			Content:    cwp.Content,
			Path:       cwp.Path,
			Embedding:  embedding,
			StartPos:   cwp.StartPos,
			EndPos:     cwp.EndPos,
			ChunkIndex: cwp.ChunkIndex,
			Metadata:   cwp.Metadata,
		}
		chunks = append(chunks, chunk)
		// Extract entities if extractor is set; errors are ignored so a
		// flaky extractor cannot fail the pipeline.
		var chunkEntities []*model.Entity
		if p.EntityExtractor != nil {
			entities, err := p.EntityExtractor(cwp.Content)
			if err == nil && entities != nil {
				chunkEntities = entities
				allEntities = append(allEntities, entities...)
			}
		}
		// Extract relations if extractor is set. Note the chunk's ltree
		// path (not a UUID) is passed as the chunkID argument.
		if p.RelationExtractor != nil {
			relations, err := p.RelationExtractor(cwp.Content, cwp.Path, chunkEntities)
			if err == nil && relations != nil {
				allRelations = append(allRelations, relations...)
			}
		}
	}
	return &ProcessingResult{
		Chunks:    chunks,
		Entities:  allEntities,
		Relations: allRelations,
	}, nil
}
package pipeline
import (
"fmt"
"math"
"regexp"
"strings"
"github.com/knights-analytics/hugot"
"github.com/knights-analytics/hugot/pipelines"
"github.com/siherrmann/grapher/helper"
"github.com/siherrmann/grapher/model"
)
// DefaultRelationExtractor creates a relation extractor using NER models.
// Uses token classification to detect citation-related entities and references.
// Detects: Citations, references, and relationships between entities.
//
// The returned edges from NER and pattern detection carry only metadata
// (no Source/Target chunk IDs) — presumably the caller links them to
// chunks; verify against the persistence layer. The hugot session is
// captured by the closure and never destroyed on the success path.
func DefaultRelationExtractor() (RelationExtractFunc, error) {
	// Prepare citation detection model (using NER to detect citation entities)
	modelName := "KnightsAnalytics/distilbert-NER"
	modelPath, err := helper.PrepareModel(modelName, "model.onnx")
	if err != nil {
		return nil, err
	}
	// Initialize hugot session with Go backend
	session, err := hugot.NewGoSession()
	if err != nil {
		return nil, fmt.Errorf("failed to create hugot session: %w", err)
	}
	// Create token classification pipeline for citation detection
	config := hugot.TokenClassificationConfig{
		ModelPath: modelPath,
		Name:      "citation-pipeline",
		Options: []hugot.TokenClassificationOption{
			pipelines.WithSimpleAggregation(),
			pipelines.WithIgnoreLabels([]string{"O"}),
		},
	}
	citationPipeline, err := hugot.NewPipeline(session, config)
	if err != nil {
		// Destroy the session on failure; report both errors if the
		// cleanup fails as well.
		if destroyErr := session.Destroy(); destroyErr != nil {
			return nil, fmt.Errorf("failed to create citation pipeline: %w (cleanup error: %v)", err, destroyErr)
		}
		return nil, fmt.Errorf("failed to create citation pipeline: %w", err)
	}
	// Fallback patterns for specific citation formats (compiled once,
	// shared by every invocation of the returned function)
	citationPatterns := []*regexp.Regexp{
		regexp.MustCompile(`\[(\d+)\]`),                                          // [1], [2]
		regexp.MustCompile(`\(([A-Z][a-z]+(?:\s+et\s+al\.)?)\s+(\d{4})\)`),       // (Smith 2020)
		regexp.MustCompile(`(?i)\b(?:section|chapter)\s+(\d+(?:\.\d+)*)\b`),      // section 3.2, chapter 5
		regexp.MustCompile(`\b(?:Section|Chapter|Figure|Table)\s+(\d+(?:\.\d+)*)\b`), // Section 3, Figure 2.1
		regexp.MustCompile(`doi:\s*(\S+)`),                                       // DOI
		regexp.MustCompile(`https?://\S+`),                                       // URLs
	}
	return func(text string, chunkPath string, entities []*model.Entity) ([]*model.Edge, error) {
		var edges []*model.Edge
		// Use NER model to detect citation-related entities.
		// The model can detect MISC (miscellaneous) entities which often
		// include citations. Pipeline errors are swallowed: NER detection
		// is best-effort and pattern matching below still runs.
		result, err := citationPipeline.RunPipeline([]string{text})
		if err == nil && len(result.Entities) > 0 {
			for _, entity := range result.Entities[0] {
				// Look for entities that might be citations.
				// MISC entities from NER often capture numbers, dates, and references.
				entityType := strings.TrimPrefix(strings.TrimPrefix(entity.Entity, "B-"), "I-")
				// Create reference edges for detected citation-like entities
				if entityType == "MISC" || entityType == "PER" {
					edge := &model.Edge{
						EdgeType:      model.EdgeTypeReference,
						Weight:        float64(entity.Score),
						Bidirectional: false,
						Metadata: map[string]interface{}{
							"citation_text":  entity.Word,
							"detection_type": "ner_model",
							"entity_type":    entityType,
							"confidence":     entity.Score,
							"extracted_from": chunkPath,
							"start":          entity.Start,
							"end":            entity.End,
						},
					}
					edges = append(edges, edge)
				}
			}
		}
		// Use pattern matching as supplementary detection for structured citations
		for _, pattern := range citationPatterns {
			matches := pattern.FindAllStringSubmatch(text, -1)
			for _, match := range matches {
				edge := &model.Edge{
					EdgeType:      model.EdgeTypeReference,
					Weight:        0.7, // Slightly lower weight for pattern-based
					Bidirectional: false,
					Metadata: map[string]interface{}{
						"citation_text":    match[0],
						"detection_type":   "pattern_supplement",
						"citation_pattern": getCitationPatternType(pattern),
						"extracted_from":   chunkPath,
					},
				}
				// Capture groups, when present, carry the reference id and year.
				if len(match) > 1 {
					edge.Metadata["reference_id"] = match[1]
					if len(match) > 2 {
						edge.Metadata["reference_year"] = match[2]
					}
				}
				edges = append(edges, edge)
			}
		}
		// Detect co-occurrence relationships between entities (all pairs)
		if len(entities) > 1 {
			for i := 0; i < len(entities); i++ {
				for j := i + 1; j < len(entities); j++ {
					entity1 := entities[i]
					entity2 := entities[j]
					// Get positions from metadata.
					// NOTE(review): assumes "start" was stored as uint (as
					// DefaultEntityExtractor does); the assertion silently
					// fails for any other numeric type — e.g. after a JSON
					// round-trip. Confirm the metadata provenance.
					start1, ok1 := entity1.Metadata["start"].(uint)
					start2, ok2 := entity2.Metadata["start"].(uint)
					if ok1 && ok2 {
						// Clamp start positions to prevent overflow
						clampedStart1 := start1
						if start1 > math.MaxInt {
							clampedStart1 = uint(math.MaxInt)
						}
						clampedStart2 := start2
						if start2 > math.MaxInt {
							clampedStart2 = uint(math.MaxInt)
						}
						// Calculate absolute distance with clamped values
						var distance int
						if clampedStart2 > clampedStart1 {
							// #nosec G115
							distance = int(clampedStart2 - clampedStart1)
						} else {
							// #nosec G115
							distance = int(clampedStart1 - clampedStart2)
						}
						// If entities are within 100 characters, create an entity mention edge
						if distance < 100 {
							edge := &model.Edge{
								SourceEntityID: &entity1.ID,
								TargetEntityID: &entity2.ID,
								EdgeType:       model.EdgeTypeEntityMention,
								Weight:         calculateCoOccurrenceWeight(distance),
								Bidirectional:  true,
								Metadata: map[string]interface{}{
									"distance":     distance,
									"context":      chunkPath,
									"entity1_type": entity1.Type,
									"entity2_type": entity2.Type,
									"entity1_name": entity1.Name,
									"entity2_name": entity2.Name,
								},
							}
							edges = append(edges, edge)
						}
					}
				}
			}
		}
		return edges, nil
	}, nil
}
// getCitationPatternType returns the type of citation pattern
func getCitationPatternType(pattern *regexp.Regexp) string {
patternStr := pattern.String()
switch {
case strings.Contains(patternStr, `\[(\d+)\]`):
return "numeric_citation"
case strings.Contains(patternStr, "et\\s+al"):
return "author_year_citation"
case strings.Contains(patternStr, "section|chapter") && strings.Contains(patternStr, "(?i)"):
return "section_reference"
case strings.Contains(patternStr, "Section|Chapter|Figure|Table"):
return "inline_section_reference"
case strings.Contains(patternStr, "doi"):
return "doi_reference"
case strings.Contains(patternStr, "https?"):
return "url_reference"
default:
return "other"
}
}
// calculateCoOccurrenceWeight converts entity proximity into an edge
// weight. Adjacent entities (distance 0) get 1.0; the weight decreases
// linearly and bottoms out at 0.0 for distances of 200 characters or more.
func calculateCoOccurrenceWeight(distance int) float64 {
	// weight = 1 - distance/200, floored at zero.
	return math.Max(0, 1.0-float64(distance)/200.0)
}
package retrieval
import (
"context"
"github.com/google/uuid"
"github.com/siherrmann/grapher/database"
"github.com/siherrmann/grapher/model"
)
// Engine provides hybrid retrieval and graph traversal capabilities,
// backed by database handlers for chunks, edges and entities.
type Engine struct {
	chunks   *database.ChunksDBHandler
	edges    *database.EdgesDBHandler
	entities *database.EntitiesDBHandler // not used by the methods visible here
}
// NewEngine creates a new retrieval engine backed by the given chunk,
// edge and entity database handlers.
func NewEngine(chunks *database.ChunksDBHandler, edges *database.EdgesDBHandler, entities *database.EntitiesDBHandler) *Engine {
	e := new(Engine)
	e.chunks = chunks
	e.edges = edges
	e.entities = entities
	return e
}
// VectorRetrieve performs pure vector similarity search over chunks,
// returning one RetrievalResult per matching chunk. A chunk without a
// similarity value scores 0.
func (e *Engine) VectorRetrieve(ctx context.Context, embedding []float32, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	chunks, err := e.chunks.SelectChunksBySimilarity(embedding, config.TopK, config.SimilarityThreshold, config.DocumentRIDs)
	if err != nil {
		return nil, err
	}
	results := make([]*model.RetrievalResult, 0, len(chunks))
	for _, chunk := range chunks {
		var score float64
		if chunk.Similarity != nil {
			score = *chunk.Similarity
		}
		results = append(results, &model.RetrievalResult{
			Chunk:           chunk,
			Score:           score,
			SimilarityScore: score,
			GraphDistance:   0,
			RetrievalMethod: "vector",
		})
	}
	return results, nil
}
// GetNeighbors retrieves immediate neighbors of a chunk.
//
// Edges are followed source->target; when followBidirectional is true,
// bidirectional edges are also followed target->source. Fix: the
// followBidirectional parameter was previously accepted but never used,
// so reverse traversal always happened regardless of the flag.
func (e *Engine) GetNeighbors(ctx context.Context, chunkID uuid.UUID, edgeTypes []model.EdgeType, followBidirectional bool) ([]*model.Chunk, error) {
	allEdges, err := e.edges.SelectEdgesFromChunk(chunkID, nil)
	if err != nil {
		return nil, err
	}
	// Filter by edge types if specified.
	var edges []*model.Edge
	if len(edgeTypes) == 0 {
		edges = allEdges
	} else {
		for _, edge := range allEdges {
			for _, edgeType := range edgeTypes {
				if edge.EdgeType == edgeType {
					edges = append(edges, edge)
					break
				}
			}
		}
	}
	var neighbors []*model.Chunk
	visited := make(map[uuid.UUID]bool)
	for _, edge := range edges {
		var targetID uuid.UUID
		// Determine target based on edge direction.
		if edge.SourceChunkID != nil && *edge.SourceChunkID == chunkID && edge.TargetChunkID != nil {
			targetID = *edge.TargetChunkID
		} else if followBidirectional && edge.Bidirectional && edge.TargetChunkID != nil && *edge.TargetChunkID == chunkID && edge.SourceChunkID != nil {
			targetID = *edge.SourceChunkID
		} else {
			continue // entity edge, or edge not touching this chunk
		}
		// Skip duplicates.
		if visited[targetID] {
			continue
		}
		visited[targetID] = true
		// Best-effort: skip chunks that fail to load.
		chunk, err := e.chunks.SelectChunk(targetID)
		if err != nil {
			continue
		}
		neighbors = append(neighbors, chunk)
	}
	return neighbors, nil
}
// GetHierarchicalContext retrieves hierarchical context for an ltree path:
// ancestors, descendants and/or siblings, depending on config. Individual
// lookups are best-effort — a failing query simply contributes nothing.
func (e *Engine) GetHierarchicalContext(ctx context.Context, path string, config *model.QueryConfig) ([]*model.Chunk, error) {
	var all []*model.Chunk
	// collect appends query results, silently ignoring lookup errors.
	collect := func(chunks []*model.Chunk, err error) {
		if err == nil {
			all = append(all, chunks...)
		}
	}
	if config.IncludeAncestors {
		collect(e.chunks.SelectAllChunksByPathAncestor(path))
	}
	if config.IncludeDescendants {
		collect(e.chunks.SelectAllChunksByPathDescendant(path))
	}
	if config.IncludeSiblings {
		collect(e.chunks.SelectSiblingChunks(path))
	}
	return all, nil
}
// TraversalResult contains a chunk and its distance from the source.
type TraversalResult struct {
	Chunk    *model.Chunk // the visited chunk
	Distance int          // number of hops from the traversal source
	Path     []uuid.UUID  // Path from source to this chunk (inclusive of both ends)
}
// BFS performs breadth-first search from a source chunk
func (e *Engine) BFS(ctx context.Context, sourceID uuid.UUID, maxHops int, edgeTypes []model.EdgeType, followBidirectional bool) ([]*TraversalResult, error) {
visited := make(map[uuid.UUID]bool)
queue := []TraversalResult{{
Chunk: nil,
Distance: 0,
Path: []uuid.UUID{sourceID},
}}
// Get source chunk
sourceChunk, err := e.chunks.SelectChunk(sourceID)
if err != nil {
return nil, err
}
queue[0].Chunk = sourceChunk
var results []*TraversalResult
visited[sourceID] = true
for len(queue) > 0 {
current := queue[0]
queue = queue[1:]
results = append(results, ¤t)
// Stop if we've reached max hops
if current.Distance >= maxHops {
continue
}
// Get edges from current chunk
allEdges, err := e.edges.SelectEdgesFromChunk(current.Chunk.ID, nil)
if err != nil {
return nil, err
}
// Filter by edge types if specified
var edges []*model.Edge
if len(edgeTypes) == 0 {
edges = allEdges
} else {
for _, edge := range allEdges {
for _, edgeType := range edgeTypes {
if edge.EdgeType == edgeType {
edges = append(edges, edge)
break
}
}
}
}
// Process each edge
for _, edge := range edges {
var targetID uuid.UUID
// Determine target based on edge direction
if edge.SourceChunkID != nil && *edge.SourceChunkID == current.Chunk.ID && edge.TargetChunkID != nil {
targetID = *edge.TargetChunkID
} else if edge.Bidirectional && edge.TargetChunkID != nil && *edge.TargetChunkID == current.Chunk.ID && edge.SourceChunkID != nil {
targetID = *edge.SourceChunkID
} else {
continue // Skip entity edges or invalid edges
}
// Skip if already visited
if visited[targetID] {
continue
}
// Get target chunk
targetChunk, err := e.chunks.SelectChunk(targetID)
if err != nil {
continue // Skip if chunk not found
}
visited[targetID] = true
// Create new path
newPath := make([]uuid.UUID, len(current.Path))
copy(newPath, current.Path)
newPath = append(newPath, targetID)
queue = append(queue, TraversalResult{
Chunk: targetChunk,
Distance: current.Distance + 1,
Path: newPath,
})
}
}
return results, nil
}
// DFS performs depth-first search from a source chunk, returning each
// reachable chunk (within maxHops edges) with its distance and the UUID
// path from the source.
func (e *Engine) DFS(ctx context.Context, sourceID uuid.UUID, maxHops int, edgeTypes []model.EdgeType, followBidirectional bool) ([]*TraversalResult, error) {
	// Resolve the traversal root before recursing.
	root, err := e.chunks.SelectChunk(sourceID)
	if err != nil {
		return nil, err
	}
	seen := make(map[uuid.UUID]bool)
	var results []*TraversalResult
	e.dfsRecursive(ctx, root, 0, maxHops, []uuid.UUID{sourceID}, edgeTypes, followBidirectional, seen, &results)
	return results, nil
}
// dfsRecursive is the recursive helper for DFS. It records current as a
// result (pre-order), then — unless maxHops is reached — follows outgoing
// chunk edges (and, when followBidirectional is true, the reverse
// direction of bidirectional edges) into unvisited neighbors.
//
// Fix: followBidirectional was previously accepted but never used; it now
// gates reverse traversal of bidirectional edges.
func (e *Engine) dfsRecursive(
	ctx context.Context,
	current *model.Chunk,
	distance int, // hops from the traversal source
	maxHops int, // recursion stops once distance reaches this bound
	path []uuid.UUID, // chunk IDs from the source to current, inclusive
	edgeTypes []model.EdgeType,
	followBidirectional bool,
	visited map[uuid.UUID]bool,
	results *[]*TraversalResult,
) {
	// Mark as visited so cycles terminate.
	visited[current.ID] = true
	// Record the current chunk with its own copy of the path.
	pathCopy := make([]uuid.UUID, len(path))
	copy(pathCopy, path)
	*results = append(*results, &TraversalResult{
		Chunk:    current,
		Distance: distance,
		Path:     pathCopy,
	})
	// Stop if we've reached max hops.
	if distance >= maxHops {
		return
	}
	// Fetch all edges, then filter locally by the requested types. On
	// error the subtree is silently abandoned (best-effort traversal).
	allEdges, err := e.edges.SelectEdgesFromChunk(current.ID, nil)
	if err != nil {
		return
	}
	var edges []*model.Edge
	if len(edgeTypes) == 0 {
		edges = allEdges
	} else {
		for _, edge := range allEdges {
			for _, edgeType := range edgeTypes {
				if edge.EdgeType == edgeType {
					edges = append(edges, edge)
					break
				}
			}
		}
	}
	for _, edge := range edges {
		var targetID uuid.UUID
		// Determine target based on edge direction.
		if edge.SourceChunkID != nil && *edge.SourceChunkID == current.ID && edge.TargetChunkID != nil {
			targetID = *edge.TargetChunkID
		} else if followBidirectional && edge.Bidirectional && edge.TargetChunkID != nil && *edge.TargetChunkID == current.ID && edge.SourceChunkID != nil {
			targetID = *edge.SourceChunkID
		} else {
			continue // entity edge, or edge not touching this chunk
		}
		if visited[targetID] {
			continue
		}
		targetChunk, err := e.chunks.SelectChunk(targetID)
		if err != nil {
			continue // skip if chunk not found
		}
		// Copy the path so sibling branches don't share backing arrays.
		newPath := make([]uuid.UUID, len(path))
		copy(newPath, path)
		newPath = append(newPath, targetID)
		e.dfsRecursive(ctx, targetChunk, distance+1, maxHops, newPath, edgeTypes, followBidirectional, visited, results)
	}
}
package retrieval
import (
"context"
"sort"
"github.com/google/uuid"
"github.com/siherrmann/grapher/model"
)
// Strategy defines a retrieval strategy
type Strategy interface {
// Retrieve returns scored chunks for the given query embedding, ranked
// best-first according to the strategy's own scoring rules.
Retrieve(ctx context.Context, embedding []float32, config *model.QueryConfig) ([]*model.RetrievalResult, error)
}
// VectorOnlyStrategy performs pure vector similarity search
type VectorOnlyStrategy struct {
// engine provides access to the underlying vector retrieval engine.
engine *Engine
}
// NewVectorOnlyStrategy creates a new vector-only strategy backed by the
// given retrieval engine.
func NewVectorOnlyStrategy(engine *Engine) *VectorOnlyStrategy {
	s := new(VectorOnlyStrategy)
	s.engine = engine
	return s
}
// Retrieve performs vector-only retrieval by delegating directly to the
// engine's vector search; no graph or hierarchy expansion is applied.
func (s *VectorOnlyStrategy) Retrieve(ctx context.Context, embedding []float32, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
return s.engine.VectorRetrieve(ctx, embedding, config)
}
// ContextualStrategy combines vector search with immediate neighbors and hierarchical context
type ContextualStrategy struct {
// engine provides vector search, neighbor lookup, and hierarchical context.
engine *Engine
}
// NewContextualStrategy creates a new contextual strategy backed by the
// given retrieval engine.
func NewContextualStrategy(engine *Engine) *ContextualStrategy {
	s := new(ContextualStrategy)
	s.engine = engine
	return s
}
// Retrieve performs contextual retrieval: each vector hit is expanded with
// its direct graph neighbors (scored seed.Score * GraphWeight) and with
// hierarchical context chunks (scored seed.Score * HierarchyWeight), then
// everything is ranked by score, best first.
func (s *ContextualStrategy) Retrieve(ctx context.Context, embedding []float32, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	seeds, err := s.engine.VectorRetrieve(ctx, embedding, config)
	if err != nil {
		return nil, err
	}

	// Deduplicate everything by chunk ID; vector hits win over expansions.
	byID := make(map[string]*model.RetrievalResult, len(seeds))
	for _, seed := range seeds {
		byID[seed.Chunk.ID.String()] = seed
	}

	for _, seed := range seeds {
		// Neighbor expansion; a failure skips all expansion for this seed.
		neighbors, nErr := s.engine.GetNeighbors(ctx, seed.Chunk.ID, config.EdgeTypes, config.FollowBidirectional)
		if nErr != nil {
			continue
		}
		for _, neighbor := range neighbors {
			key := neighbor.ID.String()
			if _, ok := byID[key]; ok {
				continue
			}
			byID[key] = &model.RetrievalResult{
				Chunk:           neighbor,
				Score:           seed.Score * config.GraphWeight,
				SimilarityScore: 0,
				GraphDistance:   1,
				RetrievalMethod: "graph_neighbor",
			}
		}
		// Hierarchical expansion; a failure skips it for this seed only.
		hierarchy, hErr := s.engine.GetHierarchicalContext(ctx, seed.Chunk.Path, config)
		if hErr != nil {
			continue
		}
		for _, hChunk := range hierarchy {
			key := hChunk.ID.String()
			if _, ok := byID[key]; ok {
				continue
			}
			byID[key] = &model.RetrievalResult{
				Chunk:           hChunk,
				Score:           seed.Score * config.HierarchyWeight,
				SimilarityScore: 0,
				GraphDistance:   0,
				RetrievalMethod: "hierarchical",
			}
		}
	}

	ranked := make([]*model.RetrievalResult, 0, len(byID))
	for _, r := range byID {
		ranked = append(ranked, r)
	}
	sort.Slice(ranked, func(a, b int) bool { return ranked[a].Score > ranked[b].Score })
	return ranked, nil
}
// MultiHopStrategy performs graph traversal from top vector results
type MultiHopStrategy struct {
// engine provides vector search and BFS traversal.
engine *Engine
}
// NewMultiHopStrategy creates a new multi-hop strategy backed by the given
// retrieval engine.
func NewMultiHopStrategy(engine *Engine) *MultiHopStrategy {
	return &MultiHopStrategy{engine: engine}
}
// Retrieve performs multi-hop retrieval: the top vector matches act as BFS
// starting points, and each newly discovered chunk is scored from the seed's
// score decayed by graph distance: seed.Score * GraphWeight / (distance + 1).
func (s *MultiHopStrategy) Retrieve(ctx context.Context, embedding []float32, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	seeds, err := s.engine.VectorRetrieve(ctx, embedding, config)
	if err != nil {
		return nil, err
	}

	collected := make(map[string]*model.RetrievalResult, len(seeds))
	for _, seed := range seeds {
		collected[seed.Chunk.ID.String()] = seed
	}

	for _, seed := range seeds {
		hops, hopErr := s.engine.BFS(ctx, seed.Chunk.ID, config.MaxHops, config.EdgeTypes, config.FollowBidirectional)
		if hopErr != nil {
			continue // best effort: a failed traversal only skips this seed
		}
		for _, hop := range hops {
			if hop.Distance == 0 {
				continue // the seed itself is already collected
			}
			key := hop.Chunk.ID.String()
			if _, seen := collected[key]; seen {
				continue
			}
			collected[key] = &model.RetrievalResult{
				Chunk:           hop.Chunk,
				Score:           seed.Score * config.GraphWeight / float64(hop.Distance+1),
				SimilarityScore: 0,
				GraphDistance:   hop.Distance,
				RetrievalMethod: "multi_hop",
			}
		}
	}

	ranked := make([]*model.RetrievalResult, 0, len(collected))
	for _, r := range collected {
		ranked = append(ranked, r)
	}
	sort.Slice(ranked, func(a, b int) bool { return ranked[a].Score > ranked[b].Score })
	return ranked, nil
}
// HybridStrategy combines vector, graph, and hierarchical signals with configurable weights
type HybridStrategy struct {
// engine provides vector search, BFS traversal, and hierarchical context.
engine *Engine
}
// NewHybridStrategy creates a new hybrid strategy backed by the given
// retrieval engine.
func NewHybridStrategy(engine *Engine) *HybridStrategy {
	return &HybridStrategy{engine: engine}
}
// Retrieve performs hybrid retrieval with weighted combination.
//
// Each vector hit contributes SimilarityScore * VectorWeight. When MaxHops
// is positive, a BFS from each hit adds GraphWeight / distance for reachable
// chunks (existing entries accumulate the bonus; new chunks start with it).
// When any hierarchical option is enabled, hierarchical context chunks add
// HierarchyWeight the same way. Results are sorted by combined score.
//
// Fix: TopK truncation now applies only when TopK > 0 (consistent with
// EntityCentricStrategy.Retrieve); previously TopK == 0 truncated the whole
// result set to empty.
func (s *HybridStrategy) Retrieve(ctx context.Context, embedding []float32, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	vectorResults, err := s.engine.VectorRetrieve(ctx, embedding, config)
	if err != nil {
		return nil, err
	}
	resultMap := make(map[string]*model.RetrievalResult)
	for _, vResult := range vectorResults {
		chunkIDStr := vResult.Chunk.ID.String()
		// Base entry: vector similarity component.
		resultMap[chunkIDStr] = &model.RetrievalResult{
			Chunk:           vResult.Chunk,
			Score:           vResult.SimilarityScore * config.VectorWeight,
			SimilarityScore: vResult.SimilarityScore,
			GraphDistance:   0,
			RetrievalMethod: "hybrid",
		}
		// Graph component: BFS from this hit, decayed by distance.
		// Traversal failures are best-effort and simply skipped.
		if config.MaxHops > 0 {
			traversalResults, tErr := s.engine.BFS(
				ctx,
				vResult.Chunk.ID,
				config.MaxHops,
				config.EdgeTypes,
				config.FollowBidirectional,
			)
			if tErr == nil {
				for _, tResult := range traversalResults {
					tChunkIDStr := tResult.Chunk.ID.String()
					if existing, exists := resultMap[tChunkIDStr]; exists {
						// Distance 0 is the source itself; no graph bonus.
						if tResult.Distance > 0 {
							existing.Score += config.GraphWeight / float64(tResult.Distance)
						}
					} else if tResult.Distance > 0 {
						resultMap[tChunkIDStr] = &model.RetrievalResult{
							Chunk:           tResult.Chunk,
							Score:           config.GraphWeight / float64(tResult.Distance),
							SimilarityScore: 0,
							GraphDistance:   tResult.Distance,
							RetrievalMethod: "hybrid",
						}
					}
				}
			}
		}
		// Hierarchy component: flat HierarchyWeight bonus per context chunk.
		if config.IncludeAncestors || config.IncludeDescendants || config.IncludeSiblings {
			hierarchicalChunks, hErr := s.engine.GetHierarchicalContext(ctx, vResult.Chunk.Path, config)
			if hErr == nil {
				for _, hChunk := range hierarchicalChunks {
					hChunkIDStr := hChunk.ID.String()
					if existing, exists := resultMap[hChunkIDStr]; exists {
						existing.Score += config.HierarchyWeight
					} else {
						resultMap[hChunkIDStr] = &model.RetrievalResult{
							Chunk:           hChunk,
							Score:           config.HierarchyWeight,
							SimilarityScore: 0,
							GraphDistance:   0,
							RetrievalMethod: "hybrid",
						}
					}
				}
			}
		}
	}
	results := make([]*model.RetrievalResult, 0, len(resultMap))
	for _, result := range resultMap {
		results = append(results, result)
	}
	sort.Slice(results, func(i, j int) bool {
		return results[i].Score > results[j].Score
	})
	// Only truncate when a positive TopK was requested; 0 means "no limit".
	if config.TopK > 0 && len(results) > config.TopK {
		results = results[:config.TopK]
	}
	return results, nil
}
// EntityCentricStrategy retrieves all chunks related to specific entities
type EntityCentricStrategy struct {
// engine is used for BFS fan-out from entity-linked chunks.
engine *Engine
// entitiesDB resolves entity-to-chunk links.
entitiesDB EntitiesDB
}
// EntitiesDB defines the interface for entity operations
type EntitiesDB interface {
// GetEntity retrieves a single entity by its string ID.
GetEntity(ctx context.Context, id string) (*model.Entity, error)
// GetChunksForEntity retrieves the chunks linked to the given entity ID.
GetChunksForEntity(ctx context.Context, entityID string) ([]*model.Chunk, error)
}
// NewEntityCentricStrategy creates a new entity-centric strategy from the
// given retrieval engine and entity store.
func NewEntityCentricStrategy(engine *Engine, entitiesDB EntitiesDB) *EntityCentricStrategy {
	s := new(EntityCentricStrategy)
	s.engine = engine
	s.entitiesDB = entitiesDB
	return s
}
// Retrieve performs entity-centric retrieval: every chunk directly linked to
// entityID is returned with full score (1.0); when MaxHops > 0 the graph is
// additionally fanned out via BFS from each linked chunk, and discovered
// chunks are scored GraphWeight / distance.
//
// NOTE(review): this method takes an entity ID rather than a query embedding,
// so EntityCentricStrategy does not satisfy the Strategy interface — confirm
// whether that is intentional.
func (s *EntityCentricStrategy) Retrieve(ctx context.Context, entityID uuid.UUID, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	linked, err := s.entitiesDB.GetChunksForEntity(ctx, entityID.String())
	if err != nil {
		return nil, err
	}

	collected := make(map[string]*model.RetrievalResult, len(linked))
	for _, c := range linked {
		collected[c.ID.String()] = &model.RetrievalResult{
			Chunk:           c,
			Score:           1.0,
			SimilarityScore: 0,
			GraphDistance:   0,
			RetrievalMethod: "entity_centric",
		}
	}

	// Optional graph fan-out from each directly linked chunk.
	if config.MaxHops > 0 {
		for _, c := range linked {
			hops, hopErr := s.engine.BFS(ctx, c.ID, config.MaxHops, config.EdgeTypes, config.FollowBidirectional)
			if hopErr != nil {
				continue // best effort: a failed traversal only skips this seed
			}
			for _, hop := range hops {
				if hop.Distance == 0 {
					continue // the seed chunk is already collected
				}
				key := hop.Chunk.ID.String()
				if _, seen := collected[key]; seen {
					continue
				}
				collected[key] = &model.RetrievalResult{
					Chunk:           hop.Chunk,
					Score:           config.GraphWeight / float64(hop.Distance),
					SimilarityScore: 0,
					GraphDistance:   hop.Distance,
					RetrievalMethod: "entity_fanout",
				}
			}
		}
	}

	ranked := make([]*model.RetrievalResult, 0, len(collected))
	for _, r := range collected {
		ranked = append(ranked, r)
	}
	sort.Slice(ranked, func(a, b int) bool { return ranked[a].Score > ranked[b].Score })
	// Truncate only when a positive TopK was requested.
	if config.TopK > 0 && len(ranked) > config.TopK {
		ranked = ranked[:config.TopK]
	}
	return ranked, nil
}
package database
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/pgvector/pgvector-go"
"github.com/siherrmann/grapher/helper"
"github.com/siherrmann/grapher/model"
loadSql "github.com/siherrmann/grapher/sql"
)
// ChunksDBHandlerFunctions defines the interface for Chunks database operations.
//
// NOTE(review): ChunksDBHandler also implements SelectSiblingChunks, which is
// not listed here — confirm whether it should be added to this interface.
type ChunksDBHandlerFunctions interface {
// InsertChunk inserts a chunk and refreshes it in place with the stored row.
InsertChunk(chunk *model.Chunk) error
// SelectChunk retrieves a chunk by ID.
SelectChunk(id uuid.UUID) (*model.Chunk, error)
// SelectAllChunksByDocument retrieves all chunks of a document.
SelectAllChunksByDocument(documentRID uuid.UUID) ([]*model.Chunk, error)
// SelectAllChunksByPathDescendant retrieves chunks below the given path.
SelectAllChunksByPathDescendant(path string) ([]*model.Chunk, error)
// SelectAllChunksByPathAncestor retrieves chunks above the given path.
SelectAllChunksByPathAncestor(path string) ([]*model.Chunk, error)
// SelectChunksBySimilarity performs vector similarity search.
SelectChunksBySimilarity(embedding []float32, limit int, threshold float64, documentRIDs []uuid.UUID) ([]*model.Chunk, error)
// SelectChunksBySimilarityWithContext performs similarity search plus hierarchical context.
SelectChunksBySimilarityWithContext(embedding []float32, limit int, includeAncestors bool, includeDescendants bool, threshold float64, documentRIDs []uuid.UUID) ([]*model.Chunk, error)
// DeleteChunk deletes a chunk by ID.
DeleteChunk(id uuid.UUID) error
// UpdateChunkEmbedding replaces a chunk's embedding vector.
UpdateChunkEmbedding(id uuid.UUID, embedding []float32) error
}
// ChunksDBHandler handles chunk-related database operations
type ChunksDBHandler struct {
// db is the shared database connection wrapper.
db *helper.Database
// NOTE(review): edgesHandler is stored but not used by any method visible
// in this file — confirm it is still needed.
edgesHandler *EdgesDBHandler // For graph operations
}
// NewChunksDBHandler creates a new chunks database handler.
// It initializes the database connection and loads chunk-related SQL functions.
// If force is true, it will reload the SQL functions even if they already exist.
func NewChunksDBHandler(db *helper.Database, edgesHandler *EdgesDBHandler, embeddingDim int, force bool) (*ChunksDBHandler, error) {
	if db == nil {
		return nil, helper.NewError("database connection validation", fmt.Errorf("database connection is nil"))
	}
	handler := &ChunksDBHandler{db: db, edgesHandler: edgesHandler}
	if err := loadSql.LoadChunksSql(handler.db.Instance, force); err != nil {
		return nil, helper.NewError("load chunks sql", err)
	}
	if err := handler.CreateTable(embeddingDim); err != nil {
		return nil, helper.NewError("create table", err)
	}
	db.Logger.Info("Initialized ChunksDBHandler")
	return handler, nil
}
// CreateTable creates the 'chunks' table in the database.
// If the table already exists, it does not create it again.
// It also creates all necessary extensions, indexes, and triggers.
//
// embeddingDim is passed to init_chunks as the pgvector column dimension.
// Fix: failures are now returned as wrapped errors instead of log.Panicf,
// so callers (e.g. NewChunksDBHandler, which already checks the error) can
// handle initialization failures gracefully.
func (h *ChunksDBHandler) CreateTable(embeddingDim int) error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// The SQL init function creates tables, triggers, and indexes idempotently.
	if _, err := h.db.Instance.ExecContext(ctx, `SELECT init_chunks($1);`, embeddingDim); err != nil {
		return helper.NewError("initializing chunks table", err)
	}
	h.db.Logger.Info("Checked/created table chunks")
	return nil
}
// InsertChunk inserts a new chunk via the insert_chunk SQL function and
// refreshes chunk in place with the row returned by the database
// (generated ID, DocumentRID, and CreatedAt included).
func (h *ChunksDBHandler) InsertChunk(chunk *model.Chunk) error {
// Pass NULL for the embedding when none is provided.
var embeddingParam interface{}
if len(chunk.Embedding) > 0 {
embeddingVector := pgvector.NewVector(chunk.Embedding)
embeddingParam = &embeddingVector
} else {
embeddingParam = nil
}
row := h.db.Instance.QueryRow(
`SELECT * FROM insert_chunk($1, $2, $3, $4, $5, $6, $7, $8)`,
chunk.DocumentID,
chunk.Content,
chunk.Path,
embeddingParam,
chunk.StartPos,
chunk.EndPos,
chunk.ChunkIndex,
chunk.Metadata,
)
// Scan order must match the column order returned by insert_chunk.
var embeddingVec *pgvector.Vector
err := row.Scan(
&chunk.ID,
&chunk.DocumentID,
&chunk.DocumentRID,
&chunk.Content,
&chunk.Path,
&embeddingVec,
&chunk.StartPos,
&chunk.EndPos,
&chunk.ChunkIndex,
&chunk.Metadata,
&chunk.CreatedAt,
)
if err != nil {
return helper.NewError("scan", err)
}
// Copy the stored embedding back only when the database returned one.
if embeddingVec != nil {
chunk.Embedding = embeddingVec.Slice()
}
return nil
}
// SelectChunk retrieves a single chunk by its ID via the select_chunk SQL
// function, decoding the stored pgvector embedding (if any) back into the
// chunk's Embedding slice.
func (h *ChunksDBHandler) SelectChunk(id uuid.UUID) (*model.Chunk, error) {
	result := &model.Chunk{}
	var vec *pgvector.Vector
	// Scan order must match the column order returned by select_chunk.
	err := h.db.Instance.QueryRow(
		`SELECT * FROM select_chunk($1)`,
		id,
	).Scan(
		&result.ID,
		&result.DocumentID,
		&result.DocumentRID,
		&result.Content,
		&result.Path,
		&vec,
		&result.StartPos,
		&result.EndPos,
		&result.ChunkIndex,
		&result.Metadata,
		&result.CreatedAt,
	)
	if err != nil {
		return nil, helper.NewError("scan", err)
	}
	if vec != nil {
		result.Embedding = vec.Slice()
	}
	return result, nil
}
// SelectAllChunksByDocument retrieves all chunks for a document, identified
// by the document's RID.
//
// NOTE(review): this selector scans metadata as raw JSON and unmarshals it
// manually, while the sibling selectors scan directly into chunk.Metadata —
// confirm which convention is intended. Also, a NULL metadata column yields a
// nil metadataJSON and json.Unmarshal will then fail — confirm metadata is
// NOT NULL in the schema.
func (h *ChunksDBHandler) SelectAllChunksByDocument(documentRID uuid.UUID) ([]*model.Chunk, error) {
rows, err := h.db.Instance.Query(
`SELECT * FROM select_chunks_by_document($1)`,
documentRID,
)
if err != nil {
return nil, helper.NewError("query", err)
}
defer rows.Close()
var chunks []*model.Chunk
for rows.Next() {
chunk := &model.Chunk{}
var embeddingVec *pgvector.Vector
var metadataJSON []byte
// Scan order must match the column order returned by select_chunks_by_document.
err := rows.Scan(
&chunk.ID,
&chunk.DocumentID,
&chunk.DocumentRID,
&chunk.Content,
&chunk.Path,
&embeddingVec,
&chunk.StartPos,
&chunk.EndPos,
&chunk.ChunkIndex,
&metadataJSON,
&chunk.CreatedAt,
)
if err != nil {
return nil, helper.NewError("scan", err)
}
// Copy the stored embedding back only when the database returned one.
if embeddingVec != nil {
chunk.Embedding = embeddingVec.Slice()
}
if err := json.Unmarshal(metadataJSON, &chunk.Metadata); err != nil {
return nil, helper.NewError("unmarshaling metadata", err)
}
chunks = append(chunks, chunk)
}
err = rows.Err()
if err != nil {
return nil, helper.NewError("rows error", err)
}
return chunks, nil
}
// SelectAllChunksByPathDescendant retrieves the chunks that are descendants
// of the given path, decoding stored embeddings where present.
func (h *ChunksDBHandler) SelectAllChunksByPathDescendant(path string) ([]*model.Chunk, error) {
	rows, err := h.db.Instance.Query(`SELECT * FROM select_chunks_by_path_descendant($1)`, path)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()

	var found []*model.Chunk
	for rows.Next() {
		c := &model.Chunk{}
		var vec *pgvector.Vector
		// Scan order must match the SQL function's output columns.
		scanErr := rows.Scan(
			&c.ID,
			&c.DocumentID,
			&c.DocumentRID,
			&c.Content,
			&c.Path,
			&vec,
			&c.StartPos,
			&c.EndPos,
			&c.ChunkIndex,
			&c.Metadata,
			&c.CreatedAt,
		)
		if scanErr != nil {
			return nil, helper.NewError("scan", scanErr)
		}
		if vec != nil {
			c.Embedding = vec.Slice()
		}
		found = append(found, c)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return found, nil
}
// SelectAllChunksByPathAncestor retrieves chunks that are ancestors of the
// given path, decoding stored embeddings where present.
func (h *ChunksDBHandler) SelectAllChunksByPathAncestor(path string) ([]*model.Chunk, error) {
rows, err := h.db.Instance.Query(
`SELECT * FROM select_chunks_by_path_ancestor($1)`,
path,
)
if err != nil {
return nil, helper.NewError("query", err)
}
defer rows.Close()
var chunks []*model.Chunk
for rows.Next() {
chunk := &model.Chunk{}
var embeddingVec *pgvector.Vector
// Scan order must match the SQL function's output columns.
err := rows.Scan(
&chunk.ID,
&chunk.DocumentID,
&chunk.DocumentRID,
&chunk.Content,
&chunk.Path,
&embeddingVec,
&chunk.StartPos,
&chunk.EndPos,
&chunk.ChunkIndex,
&chunk.Metadata,
&chunk.CreatedAt,
)
if err != nil {
return nil, helper.NewError("scan", err)
}
// Copy the stored embedding back only when the database returned one.
if embeddingVec != nil {
chunk.Embedding = embeddingVec.Slice()
}
chunks = append(chunks, chunk)
}
err = rows.Err()
if err != nil {
return nil, helper.NewError("rows error", err)
}
return chunks, nil
}
// SelectSiblingChunks retrieves chunks that are siblings of the given path (same parent, same level)
//
// NOTE(review): this method is not part of ChunksDBHandlerFunctions — confirm
// whether it should be added to the interface.
func (h *ChunksDBHandler) SelectSiblingChunks(path string) ([]*model.Chunk, error) {
rows, err := h.db.Instance.Query(
`SELECT * FROM select_sibling_chunks($1)`,
path,
)
if err != nil {
return nil, helper.NewError("query", err)
}
defer rows.Close()
var chunks []*model.Chunk
for rows.Next() {
chunk := &model.Chunk{}
var embeddingVec *pgvector.Vector
// Scan order must match the SQL function's output columns.
err := rows.Scan(
&chunk.ID,
&chunk.DocumentID,
&chunk.DocumentRID,
&chunk.Content,
&chunk.Path,
&embeddingVec,
&chunk.StartPos,
&chunk.EndPos,
&chunk.ChunkIndex,
&chunk.Metadata,
&chunk.CreatedAt,
)
if err != nil {
return nil, helper.NewError("scan", err)
}
// Copy the stored embedding back only when the database returned one.
if embeddingVec != nil {
chunk.Embedding = embeddingVec.Slice()
}
chunks = append(chunks, chunk)
}
err = rows.Err()
if err != nil {
return nil, helper.NewError("rows error", err)
}
return chunks, nil
}
// SelectChunksBySimilarity performs vector similarity search
// If documentRIDs is nil or empty, searches across all documents
//
// limit caps the number of rows, threshold is passed through to the SQL
// function, and each returned chunk carries a non-nil Similarity (0.0 when
// the database returned NULL).
func (h *ChunksDBHandler) SelectChunksBySimilarity(embedding []float32, limit int, threshold float64, documentRIDs []uuid.UUID) ([]*model.Chunk, error) {
embeddingVector := pgvector.NewVector(embedding)
// Convert documentRIDs to PostgreSQL UUID array format
var documentRIDsParam interface{}
if len(documentRIDs) > 0 {
documentRIDsParam = pq.Array(documentRIDs)
} else {
documentRIDsParam = nil
}
rows, err := h.db.Instance.Query(
`SELECT * FROM select_chunks_by_similarity($1, $2, $3, $4)`,
embeddingVector,
limit,
threshold,
documentRIDsParam,
)
if err != nil {
return nil, helper.NewError("query", err)
}
defer rows.Close()
var results []*model.Chunk
for rows.Next() {
chunk := &model.Chunk{}
var embeddingVec *pgvector.Vector
var similarity sql.NullFloat64
// Scan order must match the SQL function's output columns; the final
// column is the computed similarity score.
err := rows.Scan(
&chunk.ID,
&chunk.DocumentID,
&chunk.DocumentRID,
&chunk.Content,
&chunk.Path,
&embeddingVec,
&chunk.StartPos,
&chunk.EndPos,
&chunk.ChunkIndex,
&chunk.Metadata,
&chunk.CreatedAt,
&similarity,
)
if err != nil {
return nil, helper.NewError("scan", err)
}
if embeddingVec != nil {
chunk.Embedding = embeddingVec.Slice()
}
// Ensure similarity is always set
if similarity.Valid {
chunk.Similarity = &similarity.Float64
} else {
zero := 0.0
chunk.Similarity = &zero
}
results = append(results, chunk)
}
err = rows.Err()
if err != nil {
return nil, helper.NewError("rows error", err)
}
return results, nil
}
// SelectChunksBySimilarityWithContext performs vector similarity search with hierarchical context
// If documentRIDs is nil or empty, searches across all documents
//
// The SQL function additionally returns an IsMatch flag distinguishing
// direct similarity matches from ancestor/descendant context rows.
//
// NOTE(review): unlike SelectChunksBySimilarity, a NULL similarity leaves
// chunk.Similarity nil here instead of defaulting to 0.0 — confirm whether
// that asymmetry is intentional (context rows may legitimately have no score).
func (h *ChunksDBHandler) SelectChunksBySimilarityWithContext(
embedding []float32,
limit int,
includeAncestors bool,
includeDescendants bool,
threshold float64,
documentRIDs []uuid.UUID,
) ([]*model.Chunk, error) {
embeddingVector := pgvector.NewVector(embedding)
// Convert documentRIDs to PostgreSQL UUID array format
var documentRIDsParam interface{}
if len(documentRIDs) > 0 {
documentRIDsParam = pq.Array(documentRIDs)
} else {
documentRIDsParam = nil
}
rows, err := h.db.Instance.Query(
`SELECT * FROM select_chunks_by_similarity_with_context($1, $2, $3, $4, $5, $6)`,
embeddingVector,
limit,
includeAncestors,
includeDescendants,
threshold,
documentRIDsParam,
)
if err != nil {
return nil, helper.NewError("query", err)
}
defer rows.Close()
var results []*model.Chunk
for rows.Next() {
chunk := &model.Chunk{}
var similarity sql.NullFloat64
var embeddingVec *pgvector.Vector
// Scan order must match the SQL function's output columns; the last two
// columns are the similarity score and the IsMatch flag.
err := rows.Scan(
&chunk.ID,
&chunk.DocumentID,
&chunk.DocumentRID,
&chunk.Content,
&chunk.Path,
&embeddingVec,
&chunk.StartPos,
&chunk.EndPos,
&chunk.ChunkIndex,
&chunk.Metadata,
&chunk.CreatedAt,
&similarity,
&chunk.IsMatch,
)
if err != nil {
return nil, helper.NewError("scan", err)
}
if embeddingVec != nil {
chunk.Embedding = embeddingVec.Slice()
}
if similarity.Valid {
chunk.Similarity = &similarity.Float64
}
results = append(results, chunk)
}
err = rows.Err()
if err != nil {
return nil, helper.NewError("rows error", err)
}
return results, nil
}
// DeleteChunk deletes a chunk by ID via the delete_chunk SQL function.
func (h *ChunksDBHandler) DeleteChunk(id uuid.UUID) error {
	if _, err := h.db.Instance.Exec(`SELECT delete_chunk($1)`, id); err != nil {
		return helper.NewError("exec", err)
	}
	return nil
}
// UpdateChunkEmbedding replaces the embedding of the chunk identified by id
// with the given vector.
func (h *ChunksDBHandler) UpdateChunkEmbedding(id uuid.UUID, embedding []float32) error {
	vec := pgvector.NewVector(embedding)
	if _, err := h.db.Instance.Exec(`SELECT * FROM update_chunk_embedding($1, $2)`, id, vec); err != nil {
		return helper.NewError("exec", err)
	}
	return nil
}
package database
import (
"context"
"fmt"
"log"
"time"
"github.com/google/uuid"
"github.com/siherrmann/grapher/helper"
"github.com/siherrmann/grapher/model"
"github.com/siherrmann/grapher/sql"
)
// DocumentsDBHandlerFunctions defines the interface for Documents database operations.
type DocumentsDBHandlerFunctions interface {
// InsertDocument inserts a document and refreshes it in place with the stored row.
InsertDocument(doc *model.Document) error
// SelectDocument retrieves a document by RID.
SelectDocument(rid uuid.UUID) (*model.Document, error)
// SelectAllDocuments retrieves documents paginated by lastCreatedAt and limit.
SelectAllDocuments(lastCreatedAt *time.Time, limit int) ([]*model.Document, error)
// SelectDocumentsBySearch searches documents by a search term.
SelectDocumentsBySearch(searchTerm string, limit int) ([]*model.Document, error)
// UpdateDocument updates a document and refreshes it in place with the stored row.
UpdateDocument(doc *model.Document) error
// DeleteDocument deletes a document by RID.
DeleteDocument(rid uuid.UUID) error
}
// DocumentsDBHandler handles document-related database operations
type DocumentsDBHandler struct {
// db is the shared database connection wrapper.
db *helper.Database
}
// NewDocumentsDBHandler creates a new documents database handler.
// It initializes the database connection and loads document-related SQL functions.
// If force is true, it will reload the SQL functions even if they already exist.
func NewDocumentsDBHandler(db *helper.Database, force bool) (*DocumentsDBHandler, error) {
	if db == nil {
		return nil, helper.NewError("database connection validation", fmt.Errorf("database connection is nil"))
	}
	handler := &DocumentsDBHandler{db: db}
	if err := sql.LoadDocumentsSql(handler.db.Instance, force); err != nil {
		return nil, helper.NewError("load documents sql", err)
	}
	if err := handler.CreateTable(); err != nil {
		return nil, helper.NewError("create table", err)
	}
	db.Logger.Info("Initialized DocumentsDBHandler")
	return handler, nil
}
// CreateTable creates the 'documents' table in the database.
// If the table already exists, it does not create it again.
// It also creates all necessary indexes and triggers.
//
// Fix: failures are now returned as wrapped errors instead of log.Panicf,
// so callers (e.g. NewDocumentsDBHandler, which already checks the error)
// can handle initialization failures gracefully.
func (h *DocumentsDBHandler) CreateTable() error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// The SQL init function creates tables, triggers, and indexes idempotently.
	if _, err := h.db.Instance.ExecContext(ctx, `SELECT init_documents();`); err != nil {
		return helper.NewError("initializing documents table", err)
	}
	h.db.Logger.Info("Checked/created table documents")
	return nil
}
// InsertDocument inserts a new document via the insert_document SQL function
// and refreshes doc in place with the row returned by the database
// (generated IDs and timestamps included).
func (h *DocumentsDBHandler) InsertDocument(doc *model.Document) error {
	// Scan order must match the column order returned by insert_document.
	err := h.db.Instance.QueryRow(
		`SELECT * FROM insert_document($1, $2, $3)`,
		doc.Title,
		doc.Source,
		doc.Metadata,
	).Scan(
		&doc.ID,
		&doc.RID,
		&doc.Title,
		&doc.Source,
		&doc.Metadata,
		&doc.CreatedAt,
		&doc.UpdatedAt,
	)
	if err != nil {
		return helper.NewError("scan", err)
	}
	return nil
}
// SelectDocument retrieves a document by RID via the select_document SQL function.
func (h *DocumentsDBHandler) SelectDocument(rid uuid.UUID) (*model.Document, error) {
doc := &model.Document{}
row := h.db.Instance.QueryRow(
`SELECT * FROM select_document($1)`,
rid,
)
// Scan order must match the column order returned by select_document.
err := row.Scan(
&doc.ID,
&doc.RID,
&doc.Title,
&doc.Source,
&doc.Metadata,
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, helper.NewError("scan", err)
}
return doc, nil
}
// SelectAllDocuments retrieves documents page by page. lastCreatedAt is the
// pagination cursor (nil for the first page — exact semantics are defined by
// the select_all_documents SQL function) and limit caps the page size.
func (h *DocumentsDBHandler) SelectAllDocuments(lastCreatedAt *time.Time, limit int) ([]*model.Document, error) {
	rows, err := h.db.Instance.Query(`SELECT * FROM select_all_documents($1, $2)`, lastCreatedAt, limit)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()

	var page []*model.Document
	for rows.Next() {
		d := &model.Document{}
		// Scan order must match the SQL function's output columns.
		scanErr := rows.Scan(
			&d.ID,
			&d.RID,
			&d.Title,
			&d.Source,
			&d.Metadata,
			&d.CreatedAt,
			&d.UpdatedAt,
		)
		if scanErr != nil {
			return nil, helper.NewError("scan", scanErr)
		}
		page = append(page, d)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return page, nil
}
// SelectDocumentsBySearch searches documents by title or source
// via the search_documents SQL function; limit caps the number of results.
func (h *DocumentsDBHandler) SelectDocumentsBySearch(searchTerm string, limit int) ([]*model.Document, error) {
rows, err := h.db.Instance.Query(
`SELECT * FROM search_documents($1, $2)`,
searchTerm,
limit,
)
if err != nil {
return nil, helper.NewError("query", err)
}
defer rows.Close()
var documents []*model.Document
for rows.Next() {
doc := &model.Document{}
// Scan order must match the SQL function's output columns.
err := rows.Scan(
&doc.ID,
&doc.RID,
&doc.Title,
&doc.Source,
&doc.Metadata,
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, helper.NewError("scan", err)
}
documents = append(documents, doc)
}
err = rows.Err()
if err != nil {
return nil, helper.NewError("rows error", err)
}
return documents, nil
}
// UpdateDocument updates a document identified by doc.RID and refreshes doc
// in place with the row returned by the database (UpdatedAt included).
func (h *DocumentsDBHandler) UpdateDocument(doc *model.Document) error {
row := h.db.Instance.QueryRow(
`SELECT * FROM update_document($1, $2, $3, $4)`,
doc.RID,
doc.Title,
doc.Source,
doc.Metadata,
)
// Scan order must match the column order returned by update_document.
err := row.Scan(
&doc.ID,
&doc.RID,
&doc.Title,
&doc.Source,
&doc.Metadata,
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return helper.NewError("scan", err)
}
return nil
}
// DeleteDocument deletes a document by RID via the delete_document SQL function.
func (h *DocumentsDBHandler) DeleteDocument(rid uuid.UUID) error {
	if _, err := h.db.Instance.Exec(`SELECT delete_document($1)`, rid); err != nil {
		return helper.NewError("exec", err)
	}
	return nil
}
package database
import (
"context"
"database/sql"
"fmt"
"log"
"time"
"github.com/google/uuid"
"github.com/siherrmann/grapher/helper"
"github.com/siherrmann/grapher/model"
loadSql "github.com/siherrmann/grapher/sql"
)
// EdgesDBHandlerFunctions defines the interface for Edges database operations.
// A nil edgeType means "all edge types" throughout.
type EdgesDBHandlerFunctions interface {
// InsertEdge inserts an edge and refreshes it in place with the stored row.
InsertEdge(edge *model.Edge) error
// SelectEdge retrieves an edge by ID.
SelectEdge(id uuid.UUID) (*model.Edge, error)
// SelectEdgesFromChunk retrieves edges originating from a chunk.
SelectEdgesFromChunk(chunkID uuid.UUID, edgeType *model.EdgeType) ([]*model.Edge, error)
// SelectEdgesToChunk retrieves edges targeting a chunk.
SelectEdgesToChunk(chunkID uuid.UUID, edgeType *model.EdgeType) ([]*model.Edge, error)
// SelectEdgesConnectedToChunk retrieves edge connections for a chunk
// (presumably in either direction, given the name — see implementation).
SelectEdgesConnectedToChunk(chunkID uuid.UUID, edgeType *model.EdgeType) ([]*model.EdgeConnection, error)
// SelectEdgesFromEntity retrieves edges originating from an entity.
SelectEdgesFromEntity(entityID uuid.UUID, edgeType *model.EdgeType) ([]*model.Edge, error)
// SelectEdgesToEntity retrieves edges targeting an entity.
SelectEdgesToEntity(entityID uuid.UUID, edgeType *model.EdgeType) ([]*model.Edge, error)
// DeleteEdge deletes an edge by ID.
DeleteEdge(id uuid.UUID) error
// UpdateEdgeWeight updates the weight of an edge.
UpdateEdgeWeight(id uuid.UUID, weight float64) error
// TraverseBFSFromChunk performs a database-side BFS from a chunk up to maxDepth.
TraverseBFSFromChunk(startChunkID uuid.UUID, maxDepth int, edgeType *model.EdgeType) ([]*model.TraversalNode, error)
}
// EdgesDBHandler handles edge-related database operations
type EdgesDBHandler struct {
// db is the shared database connection wrapper.
db *helper.Database
}
// NewEdgesDBHandler creates a new edges database handler.
// It initializes the database connection and loads edge-related SQL functions.
// If force is true, it will reload the SQL functions even if they already exist.
func NewEdgesDBHandler(db *helper.Database, force bool) (*EdgesDBHandler, error) {
	if db == nil {
		return nil, helper.NewError("database connection validation", fmt.Errorf("database connection is nil"))
	}
	handler := &EdgesDBHandler{db: db}
	if err := loadSql.LoadEdgesSql(handler.db.Instance, force); err != nil {
		return nil, helper.NewError("load edges sql", err)
	}
	if err := handler.CreateTable(); err != nil {
		return nil, helper.NewError("create table", err)
	}
	db.Logger.Info("Initialized EdgesDBHandler")
	return handler, nil
}
// CreateTable creates the 'edges' table in the database.
// If the table already exists, it does not create it again.
// It also creates the edge_type enum and all necessary indexes.
//
// Fix: failures are now returned as wrapped errors instead of log.Panicf,
// so callers (e.g. NewEdgesDBHandler, which already checks the error) can
// handle initialization failures gracefully.
func (h *EdgesDBHandler) CreateTable() error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// The SQL init function creates tables, triggers, and indexes idempotently.
	if _, err := h.db.Instance.ExecContext(ctx, `SELECT init_edges();`); err != nil {
		return helper.NewError("initializing edges table", err)
	}
	h.db.Logger.Info("Checked/created table edges")
	return nil
}
// InsertEdge inserts a new edge via the insert_edge SQL function and
// refreshes edge in place with the row returned by the database
// (generated ID and CreatedAt included).
func (h *EdgesDBHandler) InsertEdge(edge *model.Edge) error {
row := h.db.Instance.QueryRow(
`SELECT * FROM insert_edge($1, $2, $3, $4, $5, $6, $7, $8)`,
edge.SourceChunkID,
edge.TargetChunkID,
edge.SourceEntityID,
edge.TargetEntityID,
edge.EdgeType,
edge.Weight,
edge.Bidirectional,
edge.Metadata,
)
// Scan order must match the column order returned by insert_edge.
err := row.Scan(
&edge.ID,
&edge.SourceChunkID,
&edge.TargetChunkID,
&edge.SourceEntityID,
&edge.TargetEntityID,
&edge.EdgeType,
&edge.Weight,
&edge.Bidirectional,
&edge.Metadata,
&edge.CreatedAt,
)
if err != nil {
return helper.NewError("scan", err)
}
return nil
}
// SelectEdge retrieves a single edge by its ID via the select_edge SQL function.
func (h *EdgesDBHandler) SelectEdge(id uuid.UUID) (*model.Edge, error) {
	result := &model.Edge{}
	// Scan order must match the column order returned by select_edge.
	err := h.db.Instance.QueryRow(
		`SELECT * FROM select_edge($1)`,
		id,
	).Scan(
		&result.ID,
		&result.SourceChunkID,
		&result.TargetChunkID,
		&result.SourceEntityID,
		&result.TargetEntityID,
		&result.EdgeType,
		&result.Weight,
		&result.Bidirectional,
		&result.Metadata,
		&result.CreatedAt,
	)
	if err != nil {
		return nil, helper.NewError("scan", err)
	}
	return result, nil
}
// SelectEdgesFromChunk retrieves edges originating from a chunk.
// When edgeType is nil, all edge types are returned (NULL filter in SQL).
func (h *EdgesDBHandler) SelectEdgesFromChunk(chunkID uuid.UUID, edgeType *model.EdgeType) ([]*model.Edge, error) {
	// Choose the query shape depending on whether a type filter was supplied.
	query := `SELECT * FROM select_edges_from_chunk($1, NULL)`
	args := []interface{}{chunkID}
	if edgeType != nil {
		query = `SELECT * FROM select_edges_from_chunk($1, $2)`
		args = append(args, *edgeType)
	}
	rows, err := h.db.Instance.Query(query, args...)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()
	var edges []*model.Edge
	for rows.Next() {
		e := &model.Edge{}
		if err := rows.Scan(
			&e.ID,
			&e.SourceChunkID,
			&e.TargetChunkID,
			&e.SourceEntityID,
			&e.TargetEntityID,
			&e.EdgeType,
			&e.Weight,
			&e.Bidirectional,
			&e.Metadata,
			&e.CreatedAt,
		); err != nil {
			return nil, helper.NewError("scan", err)
		}
		edges = append(edges, e)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return edges, nil
}
// SelectEdgesToChunk retrieves edges targeting a chunk.
// When edgeType is nil, all edge types are returned (NULL filter in SQL).
func (h *EdgesDBHandler) SelectEdgesToChunk(chunkID uuid.UUID, edgeType *model.EdgeType) ([]*model.Edge, error) {
	// Choose the query shape depending on whether a type filter was supplied.
	query := `SELECT * FROM select_edges_to_chunk($1, NULL)`
	args := []interface{}{chunkID}
	if edgeType != nil {
		query = `SELECT * FROM select_edges_to_chunk($1, $2)`
		args = append(args, *edgeType)
	}
	rows, err := h.db.Instance.Query(query, args...)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()
	var incoming []*model.Edge
	for rows.Next() {
		e := &model.Edge{}
		if err := rows.Scan(
			&e.ID,
			&e.SourceChunkID,
			&e.TargetChunkID,
			&e.SourceEntityID,
			&e.TargetEntityID,
			&e.EdgeType,
			&e.Weight,
			&e.Bidirectional,
			&e.Metadata,
			&e.CreatedAt,
		); err != nil {
			return nil, helper.NewError("scan", err)
		}
		incoming = append(incoming, e)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return incoming, nil
}
// SelectEdgesConnectedToChunk retrieves all edges connected to a chunk (both
// directions). Each result carries an IsOutgoing flag reported by the SQL
// function as an extra column.
func (h *EdgesDBHandler) SelectEdgesConnectedToChunk(chunkID uuid.UUID, edgeType *model.EdgeType) ([]*model.EdgeConnection, error) {
	query := `SELECT * FROM select_edges_connected_to_chunk($1, NULL)`
	args := []interface{}{chunkID}
	if edgeType != nil {
		query = `SELECT * FROM select_edges_connected_to_chunk($1, $2)`
		args = append(args, *edgeType)
	}
	rows, err := h.db.Instance.Query(query, args...)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()
	var connections []*model.EdgeConnection
	for rows.Next() {
		e := &model.Edge{}
		var outgoing bool
		if err := rows.Scan(
			&e.ID,
			&e.SourceChunkID,
			&e.TargetChunkID,
			&e.SourceEntityID,
			&e.TargetEntityID,
			&e.EdgeType,
			&e.Weight,
			&e.Bidirectional,
			&e.Metadata,
			&e.CreatedAt,
			&outgoing,
		); err != nil {
			return nil, helper.NewError("scan", err)
		}
		connections = append(connections, &model.EdgeConnection{
			Edge:       e,
			IsOutgoing: outgoing,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return connections, nil
}
// SelectEdgesFromEntity retrieves edges originating from an entity.
// When edgeType is nil, all edge types are returned (NULL filter in SQL).
func (h *EdgesDBHandler) SelectEdgesFromEntity(entityID uuid.UUID, edgeType *model.EdgeType) ([]*model.Edge, error) {
	query := `SELECT * FROM select_edges_from_entity($1, NULL)`
	args := []interface{}{entityID}
	if edgeType != nil {
		query = `SELECT * FROM select_edges_from_entity($1, $2)`
		args = append(args, *edgeType)
	}
	rows, err := h.db.Instance.Query(query, args...)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()
	var outgoing []*model.Edge
	for rows.Next() {
		e := &model.Edge{}
		if err := rows.Scan(
			&e.ID,
			&e.SourceChunkID,
			&e.TargetChunkID,
			&e.SourceEntityID,
			&e.TargetEntityID,
			&e.EdgeType,
			&e.Weight,
			&e.Bidirectional,
			&e.Metadata,
			&e.CreatedAt,
		); err != nil {
			return nil, helper.NewError("scan", err)
		}
		outgoing = append(outgoing, e)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return outgoing, nil
}
// SelectEdgesToEntity retrieves edges targeting an entity.
// When edgeType is nil, all edge types are returned (NULL filter in SQL).
func (h *EdgesDBHandler) SelectEdgesToEntity(entityID uuid.UUID, edgeType *model.EdgeType) ([]*model.Edge, error) {
	query := `SELECT * FROM select_edges_to_entity($1, NULL)`
	args := []interface{}{entityID}
	if edgeType != nil {
		query = `SELECT * FROM select_edges_to_entity($1, $2)`
		args = append(args, *edgeType)
	}
	rows, err := h.db.Instance.Query(query, args...)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()
	var incoming []*model.Edge
	for rows.Next() {
		e := &model.Edge{}
		if err := rows.Scan(
			&e.ID,
			&e.SourceChunkID,
			&e.TargetChunkID,
			&e.SourceEntityID,
			&e.TargetEntityID,
			&e.EdgeType,
			&e.Weight,
			&e.Bidirectional,
			&e.Metadata,
			&e.CreatedAt,
		); err != nil {
			return nil, helper.NewError("scan", err)
		}
		incoming = append(incoming, e)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return incoming, nil
}
// DeleteEdge deletes an edge by ID.
// Deleting a non-existent ID is not an error (delegated to the SQL function).
func (h *EdgesDBHandler) DeleteEdge(id uuid.UUID) error {
	if _, err := h.db.Instance.Exec(`SELECT delete_edge($1)`, id); err != nil {
		return helper.NewError("exec", err)
	}
	return nil
}
// UpdateEdgeWeight updates the weight of an edge.
func (h *EdgesDBHandler) UpdateEdgeWeight(id uuid.UUID, weight float64) error {
	if _, err := h.db.Instance.Exec(
		`SELECT * FROM update_edge_weight($1, $2)`,
		id, weight,
	); err != nil {
		return helper.NewError("exec", err)
	}
	return nil
}
// TraverseBFSFromChunk performs breadth-first search from a starting chunk.
// Traversal runs entirely inside the database via traverse_bfs_from_chunk;
// each returned node carries its depth and the path of chunk IDs from the start.
func (h *EdgesDBHandler) TraverseBFSFromChunk(startChunkID uuid.UUID, maxDepth int, edgeType *model.EdgeType) ([]*model.TraversalNode, error) {
	query := `SELECT * FROM traverse_bfs_from_chunk($1, $2, NULL)`
	args := []interface{}{startChunkID, maxDepth}
	if edgeType != nil {
		query = `SELECT * FROM traverse_bfs_from_chunk($1, $2, $3)`
		args = append(args, *edgeType)
	}
	rows, err := h.db.Instance.Query(query, args...)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()
	var nodes []*model.TraversalNode
	for rows.Next() {
		node := &model.TraversalNode{}
		var rawPath []byte
		if err := rows.Scan(&node.ChunkID, &node.Depth, &rawPath); err != nil {
			return nil, helper.NewError("scan", err)
		}
		// The path column arrives as a PostgreSQL uuid[] literal,
		// e.g. {uuid1,uuid2,uuid3}.
		if err := parseUUIDArray(rawPath, &node.Path); err != nil {
			return nil, helper.NewError("parsing path array", err)
		}
		nodes = append(nodes, node)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return nodes, nil
}
// parseUUIDArray parses the PostgreSQL uuid[] text format into result.
//
// PostgreSQL renders a uuid[] value as `{uuid1,uuid2,uuid3}` ("{}" when
// empty). The input must be brace-wrapped and every comma-separated element
// must parse as a UUID. On success *result holds the IDs in order; on a
// parse failure *result may be partially filled and an error is returned.
func parseUUIDArray(data []byte, result *[]uuid.UUID) error {
	str := string(data)
	if len(str) < 2 || str[0] != '{' || str[len(str)-1] != '}' {
		return helper.NewError("invalid array format", fmt.Errorf("%s", str))
	}
	// Remove braces
	str = str[1 : len(str)-1]
	if str == "" {
		*result = []uuid.UUID{}
		return nil
	}
	// Split on commas by index slicing rather than rune-by-rune string
	// concatenation (the old `current += string(ch)` loop was O(n^2) in the
	// array length). ',' is ASCII, so a byte scan is equivalent here.
	// A trailing empty segment is dropped, matching the previous behavior.
	var parts []string
	start := 0
	for i := 0; i < len(str); i++ {
		if str[i] == ',' {
			parts = append(parts, str[start:i])
			start = i + 1
		}
	}
	if start < len(str) {
		parts = append(parts, str[start:])
	}
	// Parse each UUID element.
	*result = make([]uuid.UUID, 0, len(parts))
	for _, part := range parts {
		id, err := uuid.Parse(part)
		if err != nil {
			return helper.NewError(fmt.Sprintf("parsing UUID %s", part), err)
		}
		*result = append(*result, id)
	}
	return nil
}
package database
import (
"context"
"fmt"
"log"
"time"
"github.com/google/uuid"
"github.com/siherrmann/grapher/helper"
"github.com/siherrmann/grapher/model"
"github.com/siherrmann/grapher/sql"
)
// EntitiesDBHandlerFunctions defines the interface for Entities database operations.
//
// NOTE(review): UpdateEntityMetadata is declared here with
// map[string]interface{} while the concrete *EntitiesDBHandler method takes
// model.Metadata. Unless model.Metadata is a type alias for
// map[string]interface{}, the handler does not satisfy this interface —
// confirm and align the two signatures.
type EntitiesDBHandlerFunctions interface {
// InsertEntity inserts a new entity; the implementation may upsert.
InsertEntity(entity *model.Entity) error
// SelectEntity retrieves an entity by its ID.
SelectEntity(id uuid.UUID) (*model.Entity, error)
// SelectEntityByName retrieves an entity by (name, type).
SelectEntityByName(name string, entityType string) (*model.Entity, error)
// SelectEntitiesBySearch searches entities by name pattern, optionally filtered by type.
SelectEntitiesBySearch(searchTerm string, entityType *string, limit int) ([]*model.Entity, error)
// SelectEntitiesByType retrieves up to limit entities of the given type.
SelectEntitiesByType(entityType string, limit int) ([]*model.Entity, error)
// DeleteEntity deletes an entity by ID.
DeleteEntity(id uuid.UUID) error
// UpdateEntityMetadata replaces an entity's metadata. See the type-mismatch note above.
UpdateEntityMetadata(id uuid.UUID, metadata map[string]interface{}) error
// SelectChunksMentioningEntity retrieves chunks that mention the entity.
SelectChunksMentioningEntity(entityID uuid.UUID) ([]*model.ChunkMention, error)
}
// EntitiesDBHandler handles entity-related database operations
type EntitiesDBHandler struct {
// db is the shared database wrapper (connection plus logger).
db *helper.Database
}
// NewEntitiesDBHandler creates a new entities database handler.
// It initializes the database connection and loads entity-related SQL functions.
// If force is true, it will reload the SQL functions even if they already exist.
func NewEntitiesDBHandler(db *helper.Database, force bool) (*EntitiesDBHandler, error) {
	if db == nil {
		return nil, helper.NewError("database connection validation", fmt.Errorf("database connection is nil"))
	}
	handler := &EntitiesDBHandler{db: db}
	// Load the SQL function definitions before touching the schema.
	if err := sql.LoadEntitiesSql(handler.db.Instance, force); err != nil {
		return nil, helper.NewError("load entities sql", err)
	}
	// Ensure the entities table and its indexes exist.
	if err := handler.CreateTable(); err != nil {
		return nil, helper.NewError("create table", err)
	}
	db.Logger.Info("Initialized EntitiesDBHandler")
	return handler, nil
}
// CreateTable creates the 'entities' table in the database.
// If the table already exists, it does not create it again.
// It also creates all necessary indexes.
//
// NOTE(review): on failure this calls log.Panicf, so the declared error
// return is never non-nil and the caller's error handling (e.g. in
// NewEntitiesDBHandler) is dead code on this path — consider returning a
// wrapped error instead of panicking in library code.
func (h *EntitiesDBHandler) CreateTable() error {
// Bound the schema-initialization call so a wedged database cannot hang startup.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Use the SQL init() function to create all tables, triggers, and indexes
_, err := h.db.Instance.ExecContext(ctx, `SELECT init_entities();`)
if err != nil {
log.Panicf("error initializing entities table: %#v", err)
}
h.db.Logger.Info("Checked/created table entities")
return nil
}
// InsertEntity inserts a new entity (or updates if exists).
// The insert_entity SQL function returns the stored row, which is scanned
// back into the caller's struct (refreshing ID and CreatedAt).
func (h *EntitiesDBHandler) InsertEntity(entity *model.Entity) error {
	stored := h.db.Instance.QueryRow(
		`SELECT * FROM insert_entity($1, $2, $3)`,
		entity.Name,
		entity.Type,
		entity.Metadata,
	)
	if err := stored.Scan(
		&entity.ID,
		&entity.Name,
		&entity.Type,
		&entity.Metadata,
		&entity.CreatedAt,
	); err != nil {
		return helper.NewError("scan", err)
	}
	return nil
}
// SelectEntity retrieves an entity by ID.
func (h *EntitiesDBHandler) SelectEntity(id uuid.UUID) (*model.Entity, error) {
	result := &model.Entity{}
	if err := h.db.Instance.QueryRow(
		`SELECT * FROM select_entity($1)`,
		id,
	).Scan(
		&result.ID,
		&result.Name,
		&result.Type,
		&result.Metadata,
		&result.CreatedAt,
	); err != nil {
		return nil, helper.NewError("scan", err)
	}
	return result, nil
}
// SelectEntityByName retrieves an entity by name and type.
func (h *EntitiesDBHandler) SelectEntityByName(name string, entityType string) (*model.Entity, error) {
	result := &model.Entity{}
	if err := h.db.Instance.QueryRow(
		`SELECT * FROM select_entity_by_name($1, $2)`,
		name,
		entityType,
	).Scan(
		&result.ID,
		&result.Name,
		&result.Type,
		&result.Metadata,
		&result.CreatedAt,
	); err != nil {
		return nil, helper.NewError("scan", err)
	}
	return result, nil
}
// SelectEntitiesBySearch searches entities by name pattern.
// entityType may be nil to search across all types; limit caps the result count.
func (h *EntitiesDBHandler) SelectEntitiesBySearch(searchTerm string, entityType *string, limit int) ([]*model.Entity, error) {
	rows, err := h.db.Instance.Query(
		`SELECT * FROM search_entities($1, $2, $3)`,
		searchTerm, entityType, limit,
	)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()
	var found []*model.Entity
	for rows.Next() {
		ent := &model.Entity{}
		if err := rows.Scan(
			&ent.ID,
			&ent.Name,
			&ent.Type,
			&ent.Metadata,
			&ent.CreatedAt,
		); err != nil {
			return nil, helper.NewError("scan", err)
		}
		found = append(found, ent)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return found, nil
}
// SelectEntitiesByType retrieves entities by type, up to limit rows.
func (h *EntitiesDBHandler) SelectEntitiesByType(entityType string, limit int) ([]*model.Entity, error) {
	rows, err := h.db.Instance.Query(
		`SELECT * FROM select_entities_by_type($1, $2)`,
		entityType, limit,
	)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()
	var found []*model.Entity
	for rows.Next() {
		ent := &model.Entity{}
		if err := rows.Scan(
			&ent.ID,
			&ent.Name,
			&ent.Type,
			&ent.Metadata,
			&ent.CreatedAt,
		); err != nil {
			return nil, helper.NewError("scan", err)
		}
		found = append(found, ent)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return found, nil
}
// DeleteEntity deletes an entity by ID.
// Deleting a non-existent ID is not an error (delegated to the SQL function).
func (h *EntitiesDBHandler) DeleteEntity(id uuid.UUID) error {
	if _, err := h.db.Instance.Exec(`SELECT delete_entity($1)`, id); err != nil {
		return helper.NewError("exec", err)
	}
	return nil
}
// UpdateEntityMetadata updates the metadata of an entity.
//
// NOTE(review): this signature takes model.Metadata while the
// EntitiesDBHandlerFunctions interface declares map[string]interface{} —
// unless model.Metadata is an alias for that map type, the handler does not
// satisfy the interface; confirm and align.
func (h *EntitiesDBHandler) UpdateEntityMetadata(id uuid.UUID, metadata model.Metadata) error {
_, err := h.db.Instance.Exec(
`SELECT * FROM update_entity_metadata($1, $2)`,
id,
metadata,
)
if err != nil {
return helper.NewError("exec", err)
}
return nil
}
// SelectChunksMentioningEntity retrieves chunks that mention an entity.
// Each mention pairs a chunk ID with the edge that links it to the entity.
func (h *EntitiesDBHandler) SelectChunksMentioningEntity(entityID uuid.UUID) ([]*model.ChunkMention, error) {
	rows, err := h.db.Instance.Query(
		`SELECT * FROM select_chunks_mentioning_entity($1)`,
		entityID,
	)
	if err != nil {
		return nil, helper.NewError("query", err)
	}
	defer rows.Close()
	var mentions []*model.ChunkMention
	for rows.Next() {
		m := &model.ChunkMention{}
		if err := rows.Scan(&m.ChunkID, &m.EdgeID, &m.EdgeMetadata); err != nil {
			return nil, helper.NewError("scan", err)
		}
		mentions = append(mentions, m)
	}
	if err := rows.Err(); err != nil {
		return nil, helper.NewError("rows error", err)
	}
	return mentions, nil
}
// GetEntity retrieves an entity by ID (alias for SelectEntity for interface compatibility)
func (h *EntitiesDBHandler) GetEntity(ctx context.Context, id string) (*model.Entity, error) {
	parsed, err := uuid.Parse(id)
	if err != nil {
		return nil, helper.NewError("parse uuid", err)
	}
	// ctx is accepted for interface compatibility; SelectEntity takes no context.
	return h.SelectEntity(parsed)
}
// GetChunksForEntity retrieves all chunks related to an entity.
//
// A chunk is considered related when an edge connects it to the entity in
// either direction (entity-as-source or entity-as-target). DISTINCT guards
// against a chunk appearing twice when multiple edges connect it.
//
// NOTE(review): the embedding column is scanned into a throwaway
// interface{} and discarded, so chunk.Embedding is never populated by this
// method — presumably deliberate (the pgvector value needs custom parsing),
// but confirm callers do not expect embeddings here.
func (h *EntitiesDBHandler) GetChunksForEntity(ctx context.Context, entityID string) ([]*model.Chunk, error) {
entityUUID, err := uuid.Parse(entityID)
if err != nil {
return nil, helper.NewError("parse uuid", err)
}
// Get edges that connect this entity to chunks
rows, err := h.db.Instance.QueryContext(ctx,
`SELECT DISTINCT c.id, c.document_id, d.rid, c.content, c.path, c.embedding,
c.start_pos, c.end_pos, c.chunk_index, c.metadata, c.created_at
FROM chunks c
LEFT JOIN documents d ON c.document_id = d.id
INNER JOIN edges e ON (
(e.source_entity_id = $1 AND e.target_chunk_id = c.id)
OR (e.target_entity_id = $1 AND e.source_chunk_id = c.id)
)
ORDER BY c.created_at`,
entityUUID,
)
if err != nil {
return nil, helper.NewError("query", err)
}
defer rows.Close()
var chunks []*model.Chunk
for rows.Next() {
chunk := &model.Chunk{}
// Sink for the embedding column; see the NOTE in the doc comment.
var embeddingVec interface{}
err := rows.Scan(
&chunk.ID,
&chunk.DocumentID,
&chunk.DocumentRID,
&chunk.Content,
&chunk.Path,
&embeddingVec,
&chunk.StartPos,
&chunk.EndPos,
&chunk.ChunkIndex,
&chunk.Metadata,
&chunk.CreatedAt,
)
if err != nil {
return nil, helper.NewError("scan", err)
}
chunks = append(chunks, chunk)
}
err = rows.Err()
if err != nil {
return nil, helper.NewError("rows error", err)
}
return chunks, nil
}
package database
import (
"context"
"fmt"
"time"
"github.com/siherrmann/grapher/helper"
)
// ChangeIndexType changes the vector index type between HNSW and IVFFlat
// indexType: "hnsw" or "ivfflat"
// params: optional parameters for index creation
// - For HNSW: "m" (int, default 16), "ef_construction" (int, default 64)
// - For IVFFlat: "lists" (int, default 100)
//
// The CREATE INDEX statement is built and validated BEFORE the old index is
// dropped, so an unsupported indexType can no longer leave the chunks table
// without any vector index (previously the drop happened first).
func (h *ChunksDBHandler) ChangeIndexType(ctx context.Context, indexType string, params map[string]interface{}) error {
	ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
	defer cancel()
	// Build the new index statement first; fail fast on bad input.
	var createIndexSQL string
	switch indexType {
	case "hnsw":
		m := 16
		efConstruction := 64
		if mVal, ok := params["m"].(int); ok {
			m = mVal
		}
		if efVal, ok := params["ef_construction"].(int); ok {
			efConstruction = efVal
		}
		// Interpolated values are ints, so Sprintf cannot inject SQL here.
		createIndexSQL = fmt.Sprintf(
			`CREATE INDEX idx_chunks_embedding ON chunks USING hnsw (embedding vector_cosine_ops) WITH (m = %d, ef_construction = %d);`,
			m, efConstruction,
		)
	case "ivfflat":
		lists := 100
		if listsVal, ok := params["lists"].(int); ok {
			lists = listsVal
		}
		createIndexSQL = fmt.Sprintf(
			`CREATE INDEX idx_chunks_embedding ON chunks USING ivfflat (embedding vector_cosine_ops) WITH (lists = %d);`,
			lists,
		)
	default:
		return helper.NewError("change index type", fmt.Errorf("unsupported index type: %s (use 'hnsw' or 'ivfflat')", indexType))
	}
	// Drop the existing index only after validation succeeded.
	_, err := h.db.Instance.ExecContext(ctx, `DROP INDEX IF EXISTS idx_chunks_embedding;`)
	if err != nil {
		return helper.NewError("drop index", err)
	}
	h.db.Logger.Info("Dropped existing vector index")
	// Create the new index
	_, err = h.db.Instance.ExecContext(ctx, createIndexSQL)
	if err != nil {
		return helper.NewError("create index", err)
	}
	h.db.Logger.Info(fmt.Sprintf("Created %s index with params: %v", indexType, params))
	return nil
}
package grapher
import (
"context"
"fmt"
"log/slog"
"os"
"github.com/google/uuid"
"github.com/siherrmann/grapher/core/pipeline"
"github.com/siherrmann/grapher/core/retrieval"
"github.com/siherrmann/grapher/database"
"github.com/siherrmann/grapher/helper"
"github.com/siherrmann/grapher/model"
loadSql "github.com/siherrmann/grapher/sql"
)
// Grapher provides a unified interface to all database handlers
type Grapher struct {
// DB is the shared database wrapper used by every handler below.
DB *helper.Database
// Chunks handles chunk storage and vector-index operations.
Chunks *database.ChunksDBHandler
// Documents handles document metadata storage.
Documents *database.DocumentsDBHandler
// Edges handles graph edges between chunks and entities.
Edges *database.EdgesDBHandler
// Entities handles named-entity storage and lookup.
Entities *database.EntitiesDBHandler
Pipeline *pipeline.Pipeline // Optional chunking pipeline; nil until SetPipeline/UseDefaultPipeline
Engine *retrieval.Engine // Retrieval engine for hybrid search
// Logging
log *slog.Logger
}
// NewGrapher creates a new Grapher instance with all handlers initialized
func NewGrapher(config *helper.DatabaseConfiguration, embeddingDim int) (*Grapher, error) {
	// Set up a pretty-printing structured logger on stdout at Info level.
	logOpts := helper.PrettyHandlerOptions{
		SlogOpts: slog.HandlerOptions{
			Level: slog.LevelInfo,
		},
	}
	logger := slog.New(helper.NewPrettyHandler(os.Stdout, logOpts))
	// Connect and install required database extensions.
	db := helper.NewDatabase("grapher", config, logger)
	if err := loadSql.Init(db.Instance); err != nil {
		return nil, helper.NewError("initialize database extensions", err)
	}
	// Handlers must be created in dependency order (documents before chunks).
	// force=false: SQL functions are not reloaded if they already exist.
	documents, err := database.NewDocumentsDBHandler(db, false)
	if err != nil {
		return nil, helper.NewError("create documents handler", err)
	}
	edges, err := database.NewEdgesDBHandler(db, false)
	if err != nil {
		return nil, helper.NewError("create edges handler", err)
	}
	chunks, err := database.NewChunksDBHandler(db, edges, embeddingDim, false)
	if err != nil {
		return nil, helper.NewError("create chunks handler", err)
	}
	entities, err := database.NewEntitiesDBHandler(db, false)
	if err != nil {
		return nil, helper.NewError("create entities handler", err)
	}
	// Wire the retrieval engine on top of the handlers.
	engine := retrieval.NewEngine(chunks, edges, entities)
	return &Grapher{
		DB:        db,
		Chunks:    chunks,
		Documents: documents,
		Edges:     edges,
		Entities:  entities,
		Engine:    engine,
		log:       logger,
	}, nil
}
// Close closes the database connection
func (g *Grapher) Close() error {
	// Nothing to close when no live connection was ever established.
	if g.DB == nil || g.DB.Instance == nil {
		return nil
	}
	return g.DB.Instance.Close()
}
// SetPipeline sets the chunking pipeline for document processing
func (g *Grapher) SetPipeline(p *pipeline.Pipeline) {
	g.Pipeline = p
}
// UseDefaultPipeline sets up the default semantic chunking and embedding pipeline
// This uses DefaultChunker with 500 char max chunks and 0.7 similarity threshold,
// DefaultEmbedder with the all-MiniLM-L6-v2 model (384 dimensions),
// DefaultEntityExtractor with distilbert-NER for entity recognition,
// and DefaultRelationExtractor with distilbert-NER for citation and reference detection
func (g *Grapher) UseDefaultPipeline() error {
	// Build each component; the chunker cannot fail, the extractors can.
	embedder, err := pipeline.DefaultEmbedder()
	if err != nil {
		return helper.NewError("create default embedder", err)
	}
	entityExtractor, err := pipeline.DefaultEntityExtractor()
	if err != nil {
		return helper.NewError("create default entity extractor", err)
	}
	relationExtractor, err := pipeline.DefaultRelationExtractor()
	if err != nil {
		return helper.NewError("create default relation extractor", err)
	}
	// Assemble the pipeline and attach the optional extractors.
	p := pipeline.NewPipeline(pipeline.DefaultChunker(500, 0.7), embedder)
	p.SetEntityExtractor(entityExtractor)
	p.SetRelationExtractor(relationExtractor)
	g.Pipeline = p
	return nil
}
// ProcessAndInsertDocument processes a document by:
// 1. Inserting the document metadata (without content)
// 2. Processing the content into chunks using the pipeline
// 3. Inserting all chunks with the document ID
// 4. Extracting and inserting entities (if entity extractor is configured)
// 5. Extracting and inserting relations/edges (if relation extractor is configured)
// The document's Content field is used for processing but not stored in the database.
// Returns the number of chunks inserted and any error encountered.
//
// NOTE(review): doc.Content is cleared before the metadata insert and never
// restored, so the caller's struct is mutated as a side effect — confirm
// callers do not rely on Content afterwards. On a chunk-insert failure the
// returned count is the number of chunks inserted before the failure.
// Entity and edge insert failures are logged and skipped (best effort).
func (g *Grapher) ProcessAndInsertDocument(doc *model.Document) (int, error) {
if g.Pipeline == nil {
return 0, helper.NewError("process document", fmt.Errorf("pipeline not set, use SetPipeline() first"))
}
if doc.Content == "" {
return 0, helper.NewError("process document", fmt.Errorf("document content is empty"))
}
// Store content temporarily and clear it before DB insert
content := doc.Content
doc.Content = ""
// Insert document metadata
if err := g.Documents.InsertDocument(doc); err != nil {
return 0, helper.NewError("insert document", err)
}
g.log.Info("Inserted document", slog.String("document_id", doc.RID.String()), slog.String("title", doc.Title))
// Process content with entity and relation extraction
result, err := g.Pipeline.ProcessWithExtraction(content, fmt.Sprintf("doc_%s", doc.RID.String()))
if err != nil {
return 0, helper.NewError("process chunks", err)
}
g.log.Info("Processed document into chunks",
slog.Int("num_chunks", len(result.Chunks)),
slog.Int("num_entities", len(result.Entities)),
slog.Int("num_relations", len(result.Relations)),
slog.String("document_id", doc.RID.String()))
// Insert all chunks and build a path-to-ID mapping
// (the mapping lets reference edges be resolved back to their source chunk below).
chunkPathToID := make(map[string]uuid.UUID)
for i, chunk := range result.Chunks {
chunk.DocumentID = doc.ID
if err := g.Chunks.InsertChunk(chunk); err != nil {
// Partial progress: i chunks were inserted before this failure.
return i, helper.NewError(fmt.Sprintf("insert chunk %d", i), err)
}
chunkPathToID[chunk.Path] = chunk.ID
}
// Insert entities
if len(result.Entities) > 0 {
for _, entity := range result.Entities {
if err := g.Entities.InsertEntity(entity); err != nil {
g.log.Error("Failed to insert entity", slog.String("entity", entity.Name), slog.String("error", err.Error()))
// Continue processing other entities even if one fails
}
}
g.log.Info("Inserted entities", slog.Int("count", len(result.Entities)))
}
// Insert relations/edges
if len(result.Relations) > 0 {
for _, edge := range result.Relations {
// For reference edges without entity IDs, link to the source chunk
if edge.SourceEntityID == nil && edge.SourceChunkID == nil {
// Get chunk ID from extracted_from metadata
if extractedFrom, ok := edge.Metadata["extracted_from"].(string); ok {
if chunkID, found := chunkPathToID[extractedFrom]; found {
edge.SourceChunkID = &chunkID
}
}
}
// Skip edges that don't have both source and target
// (e.g., citations to external documents not in our database)
hasSource := edge.SourceChunkID != nil || edge.SourceEntityID != nil
hasTarget := edge.TargetChunkID != nil || edge.TargetEntityID != nil
if !hasSource || !hasTarget {
continue // Skip this edge
}
if err := g.Edges.InsertEdge(edge); err != nil {
g.log.Error("Failed to insert edge", slog.String("type", string(edge.EdgeType)), slog.String("error", err.Error()))
// Continue processing other edges even if one fails
}
}
g.log.Info("Inserted relations", slog.Int("count", len(result.Relations)))
}
return len(result.Chunks), nil
}
// Search performs vector similarity search
func (g *Grapher) Search(ctx context.Context, query string, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	if g.Engine == nil {
		return nil, helper.NewError("vector search", fmt.Errorf("retrieval engine not initialized"))
	}
	if g.Pipeline == nil || g.Pipeline.Embedder == nil {
		return nil, helper.NewError("vector search", fmt.Errorf("pipeline with embedder not set, use SetPipeline() first"))
	}
	// Embed the query text so it can be compared against stored chunk vectors.
	vec, err := g.Pipeline.Embedder(query)
	if err != nil {
		return nil, helper.NewError("generate embedding", err)
	}
	return g.Engine.VectorRetrieve(ctx, vec, config)
}
// ContextualSearch performs contextual retrieval (vector + neighbors + hierarchy)
func (g *Grapher) ContextualSearch(ctx context.Context, query string, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	if g.Pipeline == nil || g.Pipeline.Embedder == nil {
		return nil, helper.NewError("contextual search", fmt.Errorf("pipeline with embedder not set, use SetPipeline() first"))
	}
	// Embed the query text before delegating to the contextual strategy.
	vec, err := g.Pipeline.Embedder(query)
	if err != nil {
		return nil, helper.NewError("generate embedding", err)
	}
	return retrieval.NewContextualStrategy(g.Engine).Retrieve(ctx, vec, config)
}
// MultiHopSearch performs multi-hop graph traversal retrieval
func (g *Grapher) MultiHopSearch(ctx context.Context, query string, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	if g.Pipeline == nil || g.Pipeline.Embedder == nil {
		return nil, helper.NewError("multi-hop search", fmt.Errorf("pipeline with embedder not set, use SetPipeline() first"))
	}
	// Embed the query text before delegating to the multi-hop strategy.
	vec, err := g.Pipeline.Embedder(query)
	if err != nil {
		return nil, helper.NewError("generate embedding", err)
	}
	return retrieval.NewMultiHopStrategy(g.Engine).Retrieve(ctx, vec, config)
}
// HybridSearch performs fully configurable hybrid retrieval
func (g *Grapher) HybridSearch(ctx context.Context, query string, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	if g.Pipeline == nil || g.Pipeline.Embedder == nil {
		return nil, helper.NewError("hybrid search", fmt.Errorf("pipeline with embedder not set, use SetPipeline() first"))
	}
	// Embed the query text before delegating to the hybrid strategy.
	vec, err := g.Pipeline.Embedder(query)
	if err != nil {
		return nil, helper.NewError("generate embedding", err)
	}
	return retrieval.NewHybridStrategy(g.Engine).Retrieve(ctx, vec, config)
}
// DocumentScopedSearch performs hybrid search within specific documents only
// This is optimized for single or multi-document Q&A by filtering at the database level
func (g *Grapher) DocumentScopedSearch(ctx context.Context, query string, documentRIDs []uuid.UUID, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	if g.Pipeline == nil || g.Pipeline.Embedder == nil {
		return nil, helper.NewError("document scoped search", fmt.Errorf("pipeline with embedder not set, use SetPipeline() first"))
	}
	if len(documentRIDs) == 0 {
		return nil, helper.NewError("document scoped search", fmt.Errorf("at least one document RID must be provided"))
	}
	// Embed the query text before delegating to the hybrid strategy.
	vec, err := g.Pipeline.Embedder(query)
	if err != nil {
		return nil, helper.NewError("generate embedding", err)
	}
	// Narrow retrieval to the requested documents via the query config.
	if config == nil {
		config = &model.QueryConfig{}
	}
	config.DocumentRIDs = documentRIDs
	return retrieval.NewHybridStrategy(g.Engine).Retrieve(ctx, vec, config)
}
// EntityCentricSearch performs entity-centric retrieval.
// Guards against an uninitialized retrieval engine (consistent with Search),
// returning an error instead of a nil-pointer panic on a partially
// constructed Grapher.
func (g *Grapher) EntityCentricSearch(ctx context.Context, entityID uuid.UUID, config *model.QueryConfig) ([]*model.RetrievalResult, error) {
	if g.Engine == nil {
		return nil, helper.NewError("entity centric search", fmt.Errorf("retrieval engine not initialized"))
	}
	strategy := retrieval.NewEntityCentricStrategy(g.Engine, g.Entities)
	return strategy.Retrieve(ctx, entityID, config)
}
// BFSTraversal performs breadth-first search from a chunk.
// Guards against an uninitialized retrieval engine (consistent with Search),
// returning an error instead of a nil-pointer panic.
func (g *Grapher) BFSTraversal(ctx context.Context, sourceID uuid.UUID, maxHops int, edgeTypes []model.EdgeType, followBidirectional bool) ([]*retrieval.TraversalResult, error) {
	if g.Engine == nil {
		return nil, helper.NewError("bfs traversal", fmt.Errorf("retrieval engine not initialized"))
	}
	return g.Engine.BFS(ctx, sourceID, maxHops, edgeTypes, followBidirectional)
}
// DFSTraversal performs depth-first search from a chunk.
// Guards against an uninitialized retrieval engine (consistent with Search),
// returning an error instead of a nil-pointer panic.
func (g *Grapher) DFSTraversal(ctx context.Context, sourceID uuid.UUID, maxHops int, edgeTypes []model.EdgeType, followBidirectional bool) ([]*retrieval.TraversalResult, error) {
	if g.Engine == nil {
		return nil, helper.NewError("dfs traversal", fmt.Errorf("retrieval engine not initialized"))
	}
	return g.Engine.DFS(ctx, sourceID, maxHops, edgeTypes, followBidirectional)
}
// ChangeIndexType changes the vector index type between HNSW and IVFFlat.
// Guards against an uninitialized chunks handler (consistent with the other
// Grapher guard checks), returning an error instead of a nil-pointer panic.
func (g *Grapher) ChangeIndexType(ctx context.Context, indexType string, params map[string]interface{}) error {
	if g.Chunks == nil {
		return helper.NewError("change index type", fmt.Errorf("chunks handler not initialized"))
	}
	return g.Chunks.ChangeIndexType(ctx, indexType, params)
}
package helper
import (
"context"
"database/sql"
"fmt"
"log"
"log/slog"
"os"
"strconv"
"strings"
"sync"
"time"
_ "github.com/joho/godotenv/autoload"
"github.com/lib/pq"
)
// Database represents a service that interacts with a database.
type Database struct {
// Name identifies this database service (used for logging context).
Name string
// Logger receives connection and schema lifecycle messages.
Logger *slog.Logger
// Instance is the live connection pool; nil when constructed without a configuration.
Instance *sql.DB
}
// NewDatabase constructs a Database. With a non-nil dbConfig it connects
// immediately and panics if no connection could be established; with a nil
// dbConfig it returns a shell with Instance == nil.
func NewDatabase(name string, dbConfig *DatabaseConfiguration, logger *slog.Logger) *Database {
	// No configuration: return an unconnected shell.
	if dbConfig == nil {
		return &Database{
			Name:     name,
			Logger:   logger,
			Instance: nil,
		}
	}
	db := &Database{Name: name, Logger: logger}
	db.ConnectToDatabase(dbConfig, logger)
	if db.Instance == nil {
		panic("error connecting to database")
	}
	return db
}
// NewDatabaseWithDB wraps an existing *sql.DB connection in a Database.
// Unlike NewDatabase, it performs no connecting or validation itself.
// (Fixes the misspelled parameter name "dbConnnection"; Go parameter names
// are not part of the caller-visible API, so this is non-breaking.)
func NewDatabaseWithDB(name string, dbConnection *sql.DB, logger *slog.Logger) *Database {
	return &Database{
		Name:     name,
		Logger:   logger,
		Instance: dbConnection,
	}
}
// DatabaseConfiguration holds the PostgreSQL connection settings,
// typically populated from GRAPHER_DB_* environment variables.
type DatabaseConfiguration struct {
// Host is the database server hostname or IP.
Host string
// Port is the database server port (kept as a string for URL building).
Port string
// Database is the database name to connect to.
Database string
// Username is the login role.
Username string
// Password is the login password.
Password string
// Schema is used as the connection's search_path.
Schema string
// SSLMode overrides the sslmode parameter; defaults to "require" when blank.
SSLMode string
// WithTableDrop enables destructive table recreation when true.
WithTableDrop bool
}
// NewDatabaseConfiguration creates a new DatabaseConfiguration instance.
// It reads the database configuration from environment variables.
// It returns a pointer to the new DatabaseConfiguration instance or an error if any required environment variable is not set.
func NewDatabaseConfiguration() (*DatabaseConfiguration, error) {
	config := &DatabaseConfiguration{
		Host:          os.Getenv("GRAPHER_DB_HOST"),
		Port:          os.Getenv("GRAPHER_DB_PORT"),
		Database:      os.Getenv("GRAPHER_DB_DATABASE"),
		Username:      os.Getenv("GRAPHER_DB_USERNAME"),
		Password:      os.Getenv("GRAPHER_DB_PASSWORD"),
		Schema:        os.Getenv("GRAPHER_DB_SCHEMA"),
		SSLMode:       os.Getenv("GRAPHER_DB_SSLMODE"),
		WithTableDrop: os.Getenv("GRAPHER_DB_WITH_TABLE_DROP") == "true",
	}
	// All of these must be non-blank; SSLMode and WithTableDrop are optional.
	required := []string{config.Host, config.Port, config.Database, config.Username, config.Password, config.Schema}
	for _, value := range required {
		if len(strings.TrimSpace(value)) == 0 {
			return nil, fmt.Errorf("GRAPHER_DB_HOST, GRAPHER_DB_PORT, GRAPHER_DB_DATABASE, GRAPHER_DB_USERNAME, GRAPHER_DB_PASSWORD and GRAPHER_DB_SCHEMA environment variables must be set")
		}
	}
	return config, nil
}
// DatabaseConnectionString builds the postgres:// URL for this configuration.
// SSLMode defaults to "require" when blank; Schema becomes the search_path.
func (d *DatabaseConfiguration) DatabaseConnectionString() string {
	sslMode := d.SSLMode
	if len(strings.TrimSpace(sslMode)) == 0 {
		sslMode = "require"
	}
	return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?application_name=grapher&sslmode=%s&search_path=%s", d.Username, d.Password, d.Host, d.Port, d.Database, sslMode, d.Schema)
}
// Internal function for the service creation to connect to a database.
// DatabaseConfiguration must contain uri, username and password.
// It initializes the database connection and sets the Instance field of the Database struct.
// Panics (directly or via log.Panic*) on configuration, parse, or ping
// failure — callers receive no error return to handle.
func (d *Database) ConnectToDatabase(dbConfig *DatabaseConfiguration, logger *slog.Logger) {
// All required fields must be non-blank before attempting a connection.
if len(strings.TrimSpace(dbConfig.Host)) == 0 || len(strings.TrimSpace(dbConfig.Port)) == 0 || len(strings.TrimSpace(dbConfig.Database)) == 0 || len(strings.TrimSpace(dbConfig.Username)) == 0 || len(strings.TrimSpace(dbConfig.Password)) == 0 || len(strings.TrimSpace(dbConfig.Schema)) == 0 {
panic("database configuration must contain uri, username and password")
}
// NOTE(review): connectOnce is a local variable, so a fresh Once is created
// on every call and the Do body always executes — it does not provide
// process-wide connect-once semantics. Confirm whether it should be a
// package-level var or a Database field instead.
var connectOnce sync.Once
var db *sql.DB
connectOnce.Do(func() {
dsn, err := pq.ParseURL(dbConfig.DatabaseConnectionString())
if err != nil {
log.Panicf("error parsing database connection string: %s", err.Error())
}
base, err := pq.NewConnector(dsn)
if err != nil {
log.Panic(err)
}
// Install a notice handler that deliberately discards server notices
// (the logging line is intentionally commented out).
db = sql.OpenDB(pq.ConnectorWithNoticeHandler(base, func(notice *pq.Error) {
// log.Printf("Notice sent: %s", notice.Message)
}))
// 0 means no limit on concurrently open connections (database/sql semantics).
db.SetMaxOpenConns(0)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Best-effort: failure to create the pg_trgm extension is logged, not fatal.
_, err = db.ExecContext(
ctx,
"CREATE EXTENSION IF NOT EXISTS pg_trgm;",
)
if err != nil {
logger.Error(err.Error())
}
// A failed ping is fatal, unlike the extension step above.
pingErr := db.Ping()
if pingErr != nil {
log.Panicf("error connecting to database: %s", pingErr.Error())
}
logger.Info("Connected to db")
})
d.Instance = db
}
// CheckTableExistance checks if a table with the specified name exists in the database.
// It queries information_schema.tables to check for the existence of the table.
// It returns true if the table exists, false otherwise, and an error if the query fails.
func (d *Database) CheckTableExistance(tableName string) (bool, error) {
	exists := false
	// Use a parameterized placeholder instead of interpolating a quoted
	// literal into the SQL text — idiomatic for database/sql and removes any
	// reliance on manual quoting.
	row := d.Instance.QueryRow(
		`SELECT EXISTS (
			SELECT 1
			FROM information_schema.tables
			WHERE table_name = $1
		) AS table_existence`,
		tableName,
	)
	if err := row.Scan(&exists); err != nil {
		return false, err
	}
	return exists, nil
}
// CreateIndex creates an index on the specified column of the specified table.
// It uses the PostgreSQL CREATE INDEX statement to create the index.
// If the index already exists, it will not create a new one.
// It returns an error if the index creation fails.
func (d *Database) CreateIndex(tableName string, columnName string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Identifiers cannot be bound as query parameters, so quote them.
	tableNameQuoted := pq.QuoteIdentifier(tableName)
	indexQuoted := pq.QuoteIdentifier("idx_" + tableName + "_" + columnName)
	columnNameQuoted := pq.QuoteIdentifier(columnName)
	_, err := d.Instance.ExecContext(
		ctx,
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s(%s)", indexQuoted, tableNameQuoted, columnNameQuoted),
	)
	if err != nil {
		// Wrap with %w (was %#v, which printed the error's struct form and
		// broke errors.Is / errors.As for callers).
		return fmt.Errorf("error creating %s index: %w", indexQuoted, err)
	}
	return nil
}
// CreateIndexes creates an index for each of the given columns on tableName
// by delegating to CreateIndex. It stops and returns the first error
// encountered; columns after a failure are not attempted.
func (d *Database) CreateIndexes(tableName string, columnNames ...string) error {
	for i := range columnNames {
		if err := d.CreateIndex(tableName, columnNames[i]); err != nil {
			return err
		}
	}
	return nil
}
// CreateCombinedIndex creates a combined index on the specified columns of the specified table.
// It uses the PostgreSQL CREATE INDEX statement to create the index.
// If the index already exists, it will not create a new one.
// It returns an error if the index creation fails.
func (d *Database) CreateCombinedIndex(tableName string, columnName1 string, columnName2 string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Identifiers cannot be bound as query parameters, so quote them.
	tableNameQuoted := pq.QuoteIdentifier(tableName)
	indexQuoted := pq.QuoteIdentifier("idx_" + tableName + "_" + columnName1 + "_" + columnName2)
	columnName1Quoted := pq.QuoteIdentifier(columnName1)
	columnName2Quoted := pq.QuoteIdentifier(columnName2)
	_, err := d.Instance.ExecContext(
		ctx,
		fmt.Sprintf(`CREATE INDEX IF NOT EXISTS %s ON %s (%s, %s)`, indexQuoted, tableNameQuoted, columnName1Quoted, columnName2Quoted),
	)
	if err != nil {
		// Wrap with %w (was %#v) so the cause stays inspectable via errors.Is/As.
		return fmt.Errorf("error creating %s index: %w", indexQuoted, err)
	}
	return nil
}
// CreateUniqueCombinedIndex creates a unique combined index on the specified columns of the specified table.
// It uses the PostgreSQL CREATE UNIQUE INDEX statement to create the index.
// If the index already exists, it will not create a new one.
// It returns an error if the index creation fails.
func (d *Database) CreateUniqueCombinedIndex(tableName string, columnName1 string, columnName2 string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Identifiers cannot be bound as query parameters, so quote them.
	tableNameQuoted := pq.QuoteIdentifier(tableName)
	indexQuoted := pq.QuoteIdentifier("idx_" + tableName + "_" + columnName1 + "_" + columnName2)
	columnName1Quoted := pq.QuoteIdentifier(columnName1)
	columnName2Quoted := pq.QuoteIdentifier(columnName2)
	_, err := d.Instance.ExecContext(
		ctx,
		fmt.Sprintf(`CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s (%s, %s)`, indexQuoted, tableNameQuoted, columnName1Quoted, columnName2Quoted),
	)
	if err != nil {
		// Wrap with %w (was %#v) so the cause stays inspectable via errors.Is/As.
		return fmt.Errorf("error creating %s index: %w", indexQuoted, err)
	}
	return nil
}
// DropIndex drops the index on the specified table and column.
// It uses the PostgreSQL DROP INDEX statement to drop the index.
// If the index does not exist, it will not return an error.
// It returns an error if the index dropping fails.
func (d *Database) DropIndex(tableName string, columnName string) error {
	// Renamed parameter from the copy-pasted "jsonMapKey" — this method
	// mirrors CreateIndex's idx_<table>_<column> naming convention.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	indexQuoted := pq.QuoteIdentifier("idx_" + tableName + "_" + columnName)
	_, err := d.Instance.ExecContext(
		ctx,
		fmt.Sprintf(`DROP INDEX IF EXISTS %s;`, indexQuoted),
	)
	if err != nil {
		// Wrap with %w (was %#v) so the cause stays inspectable via errors.Is/As.
		return fmt.Errorf("error dropping %s index: %w", indexQuoted, err)
	}
	return nil
}
// Health checks the health of the database connection by pinging the database.
// It returns a map with keys indicating various health statistics.
// A failed ping is reported in the map (status "down") and logged — it no
// longer panics; the previous log.Panicf terminated the process and made
// the subsequent `return stats` unreachable, defeating the purpose of a
// health probe.
func (d *Database) Health() map[string]string {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	stats := make(map[string]string)
	// Ping the database
	if err := d.Instance.PingContext(ctx); err != nil {
		stats["status"] = "down"
		stats["error"] = fmt.Sprintf("db down: %v", err)
		log.Printf("db down: %v", err) // report, but let the caller decide how to react
		return stats
	}
	// Database is up, add more statistics
	stats["status"] = "up"
	stats["message"] = "It's healthy"
	// Get database stats (like open connections, in use, idle, etc.)
	dbStats := d.Instance.Stats()
	stats["open_connections"] = strconv.Itoa(dbStats.OpenConnections)
	stats["in_use"] = strconv.Itoa(dbStats.InUse)
	stats["idle"] = strconv.Itoa(dbStats.Idle)
	stats["wait_count"] = strconv.FormatInt(dbStats.WaitCount, 10)
	stats["wait_duration"] = dbStats.WaitDuration.String()
	stats["max_idle_closed"] = strconv.FormatInt(dbStats.MaxIdleClosed, 10)
	stats["max_lifetime_closed"] = strconv.FormatInt(dbStats.MaxLifetimeClosed, 10)
	// Evaluate stats to provide a health message
	if dbStats.OpenConnections > 40 { // Assuming 50 is the max for this example
		stats["message"] = "The database is experiencing heavy load."
	}
	if dbStats.WaitCount > 1000 {
		stats["message"] = "The database has a high number of wait events, indicating potential bottlenecks."
	}
	if dbStats.MaxIdleClosed > int64(dbStats.OpenConnections)/2 {
		stats["message"] = "Many idle connections are being closed, consider revising the connection pool settings."
	}
	if dbStats.MaxLifetimeClosed > int64(dbStats.OpenConnections)/2 {
		stats["message"] = "Many connections are being closed due to max lifetime, consider increasing max lifetime or revising the connection usage pattern."
	}
	return stats
}
// Close closes the database connection.
// It logs a message naming the database being disconnected.
// It returns the error from closing the underlying connection, or nil.
func (d *Database) Close() error {
	// Log the database name; the previous %v on d.Instance printed the
	// *sql.DB pointer value, which identifies nothing useful.
	log.Printf("Disconnected from database: %s", d.Name)
	return d.Instance.Close()
}
package helper
import (
"context"
"fmt"
"log/slog"
"net/url"
"testing"
"time"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
// Fixed credentials for the throwaway test database container started by
// MustStartPostgresContainer.
const (
dbName = "database"
dbUser = "user"
dbPwd = "password"
)
// MustStartPostgresContainer starts a PostgreSQL container for testing purposes.
// It uses the timescale/timescaledb image with PostgreSQL 17.
// It returns a function to terminate the container, the host port on which
// the database is accessible, and an error if the container could not be started.
// NOTE(review): despite the "Must" prefix this returns an error rather than
// panicking — confirm whether the name or the behavior should change.
func MustStartPostgresContainer() (func(ctx context.Context, opts ...testcontainers.TerminateOption) error, string, error) {
	ctx := context.Background()

	container, err := postgres.Run(
		ctx,
		"timescale/timescaledb:latest-pg17",
		postgres.WithDatabase(dbName),
		postgres.WithUsername(dbUser),
		postgres.WithPassword(dbPwd),
		testcontainers.WithWaitStrategy(
			// Postgres logs "ready to accept connections" twice during init;
			// wait for the second occurrence.
			wait.ForLog("database system is ready to accept connections").
				WithOccurrence(2).
				WithStartupTimeout(5*time.Second),
		),
	)
	if err != nil {
		return nil, "", fmt.Errorf("error starting postgres container: %w", err)
	}

	connStr, err := container.ConnectionString(ctx, "sslmode=disable")
	if err != nil {
		return nil, "", fmt.Errorf("error getting connection string: %w", err)
	}

	parsed, err := url.Parse(connStr)
	if err != nil {
		return nil, "", fmt.Errorf("error parsing connection string: %v", err)
	}

	return container.Terminate, parsed.Port(), nil
}
// NewTestDatabase creates a new Database instance for testing purposes.
// The database is always named "test_db" and logs via slog's default logger.
func NewTestDatabase(config *DatabaseConfiguration) *Database {
	const testDatabaseName = "test_db"
	db := NewDatabase(testDatabaseName, config, slog.Default())
	return db
}
// SetTestDatabaseConfigEnvs sets the environment variables for the test
// database configuration: host, port, database name, credentials, schema,
// SSL mode and the table-drop flag. t.Setenv restores the originals when
// the test ends.
func SetTestDatabaseConfigEnvs(t *testing.T, port string) {
	envs := [][2]string{
		{"GRAPHER_DB_HOST", "localhost"},
		{"GRAPHER_DB_PORT", port},
		{"GRAPHER_DB_DATABASE", dbName},
		{"GRAPHER_DB_USERNAME", dbUser},
		{"GRAPHER_DB_PASSWORD", dbPwd},
		{"GRAPHER_DB_SCHEMA", "public"},
		{"GRAPHER_DB_SSLMODE", "disable"},
		{"GRAPHER_DB_WITH_TABLE_DROP", "true"},
	}
	for _, kv := range envs {
		t.Setenv(kv[0], kv[1])
	}
}
package helper
import (
"fmt"
"path"
"runtime"
"strings"
)
type Error struct {
Original error
Trace []string
}
func (e Error) Error() string {
return e.Original.Error() + " | Trace: " + fmt.Sprint(strings.Join(e.Trace, ", "))
}
func NewError(trace string, original error) Error {
pc, _, _, ok := runtime.Caller(1)
details := runtime.FuncForPC(pc)
if ok && details != nil {
functionName := path.Base(details.Name())
trace = functionName + " - " + trace
}
if v, ok := original.(Error); ok {
return Error{
Original: v.Original,
Trace: append(v.Trace, trace),
}
}
return Error{
Original: original,
Trace: []string{trace},
}
}
package helper
import (
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/knights-analytics/hugot"
)
// PrepareModel ensures the model named modelName is present under ./models
// and returns its local path, downloading it via hugot when missing.
// onnxFilePath, when non-empty, selects a specific ONNX file inside the model.
func PrepareModel(modelName string, onnxFilePath string) (string, error) {
	modelDir := "./models"
	// Sanitize the model name for use as a directory name: hub-style names
	// such as "org/model" contain path separators, so replace every "/"
	// with "_". (The previous Dir+Base approach left separators in
	// multi-level names like "a/b/c" -> "a/b_c".) Single-level names
	// produce the same result as before, so existing layouts still resolve.
	sanitizedName := strings.ReplaceAll(modelName, "/", "_")
	modelPath := filepath.Join(modelDir, sanitizedName)
	// Check if model exists, if not download it
	if _, err := os.Stat(modelPath); os.IsNotExist(err) {
		if err := os.MkdirAll(modelDir, 0750); err != nil {
			return "", fmt.Errorf("failed to create model directory: %w", err)
		}
		downloadOptions := hugot.NewDownloadOptions()
		if onnxFilePath != "" {
			downloadOptions.OnnxFilePath = onnxFilePath
		}
		downloadedPath, err := hugot.DownloadModel(modelName, modelDir, downloadOptions)
		if err != nil {
			return "", fmt.Errorf("failed to download model: %w", err)
		}
		modelPath = downloadedPath
	}
	return modelPath, nil
}
package helper
import (
"context"
"encoding/json"
"io"
"log"
"log/slog"
"github.com/fatih/color"
)
// PrettyHandlerOptions configures a PrettyHandler.
type PrettyHandlerOptions struct {
SlogOpts slog.HandlerOptions // forwarded to the embedded JSON handler
}
// PrettyHandler is a slog.Handler that prints colorized, human-readable log
// lines through an internal *log.Logger; non-overridden handler behavior is
// delegated to the embedded slog.Handler.
type PrettyHandler struct {
slog.Handler
l *log.Logger
}
// Handle formats a single log record as a colorized line: timestamp, level,
// message, then the record's attributes as indented JSON.
func (h *PrettyHandler) Handle(ctx context.Context, r slog.Record) error {
	level := r.Level.String() + ":"
	switch r.Level {
	case slog.LevelDebug:
		level = color.MagentaString(level)
	case slog.LevelInfo:
		level = color.BlueString(level)
	case slog.LevelWarn:
		level = color.YellowString(level)
	case slog.LevelError:
		level = color.RedString(level)
	}
	fields := make(map[string]interface{}, r.NumAttrs())
	r.Attrs(func(a slog.Attr) bool {
		fields[a.Key] = a.Value.Any()
		return true
	})
	b, err := json.MarshalIndent(fields, "", "  ")
	if err != nil {
		return err
	}
	// Fixed time layout: Go's reference time is 15:04:05 — the previous
	// "[15:05:05.000]" used the seconds placeholder (05) in the minutes
	// position, so minutes were printed as the seconds value.
	timeStr := r.Time.Format("[15:04:05.000]")
	msg := color.CyanString(r.Message)
	h.l.Println(timeStr, level, msg, color.WhiteString(string(b)))
	return nil
}
// NewPrettyHandler builds a PrettyHandler that writes to out, delegating
// non-overridden handler behavior to a JSON slog handler configured from opts.
func NewPrettyHandler(out io.Writer, opts PrettyHandlerOptions) *PrettyHandler {
	return &PrettyHandler{
		Handler: slog.NewJSONHandler(out, &opts.SlogOpts),
		l:       log.New(out, "", 0),
	}
}
package model
import "github.com/google/uuid"
// QueryConfig represents configuration for a retrieval query
type QueryConfig struct {
// Vector search parameters
TopK int `json:"top_k"` // maximum number of vector-search results
SimilarityThreshold float64 `json:"similarity_threshold,omitempty"` // minimum similarity; zero disables the cutoff
// Document filtering
DocumentRIDs []uuid.UUID `json:"document_rids,omitempty"` // Filter by specific documents
// Graph traversal parameters
MaxHops int `json:"max_hops,omitempty"` // maximum traversal depth from each hit
EdgeTypes []EdgeType `json:"edge_types,omitempty"` // Filter by edge types; nil means all types
FollowBidirectional bool `json:"follow_bidirectional"` // also traverse bidirectional edges in reverse
// Ltree parameters
IncludeAncestors bool `json:"include_ancestors"` // expand results with hierarchy ancestors
IncludeDescendants bool `json:"include_descendants"` // expand results with hierarchy descendants
IncludeSiblings bool `json:"include_siblings"` // expand results with hierarchy siblings
// Ranking parameters
VectorWeight float64 `json:"vector_weight"` // Weight for similarity score
GraphWeight float64 `json:"graph_weight"` // Weight for graph distance
HierarchyWeight float64 `json:"hierarchy_weight"` // Weight for hierarchy distance
}
// DefaultQueryConfig returns a sensible default configuration: top-5 vector
// hits at a 0.7 similarity threshold, 2-hop traversal over all edge types
// (bidirectional), sibling expansion only, and ranking weights of
// 0.6 / 0.3 / 0.1 for vector / graph / hierarchy.
func DefaultQueryConfig() QueryConfig {
	cfg := QueryConfig{
		TopK:                5,
		SimilarityThreshold: 0.7,
		MaxHops:             2,
		FollowBidirectional: true,
		IncludeSiblings:     true,
		VectorWeight:        0.6,
		GraphWeight:         0.3,
		HierarchyWeight:     0.1,
	}
	// EdgeTypes stays nil (all types); ancestor/descendant expansion stays
	// off via the zero value.
	return cfg
}
package model
import (
"os"
"path/filepath"
"time"
"github.com/google/uuid"
)
// Document represents a source document
type Document struct {
ID int64 `json:"id"` // database primary key
RID uuid.UUID `json:"rid"` // resource UUID (used e.g. by QueryConfig.DocumentRIDs)
Title string `json:"title"` // defaults to the filename when loaded from disk
Source string `json:"source,omitempty"` // origin, e.g. the file path
Content string `json:"content,omitempty" db:"-"` // Temporary field for processing, not stored in DB
Metadata Metadata `json:"metadata,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// NewDocumentFromFile reads a file and creates a Document with the file content.
// Title defaults to the filename without its extension (or the full filename
// when stripping the extension leaves nothing, e.g. ".gitignore"); Source is
// the given path.
func NewDocumentFromFile(filePath string, metadata Metadata) (*Document, error) {
	// Clean the path to prevent directory traversal
	cleanPath := filepath.Clean(filePath)
	// #nosec G304 - This function intentionally reads user-specified files as part of document ingestion
	content, err := os.ReadFile(cleanPath)
	if err != nil {
		return nil, err
	}
	filename := filepath.Base(filePath)
	ext := filepath.Ext(filename)
	title := filename[:len(filename)-len(ext)]
	if len(title) == 0 {
		title = filename
	}
	doc := &Document{
		Title:    title,
		Source:   filePath,
		Content:  string(content),
		Metadata: metadata,
	}
	return doc, nil
}
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"github.com/siherrmann/grapher/helper"
)
// Metadata represents JSONB metadata stored in PostgreSQL
type Metadata map[string]interface{}

// Value implements the driver.Valuer interface for database storage
func (m Metadata) Value() (driver.Value, error) {
	return m.Marshal()
}

// Scan implements the sql.Scanner interface for database retrieval
func (m *Metadata) Scan(value interface{}) error {
	return m.Unmarshal(value)
}

// Marshal converts Metadata to JSON bytes
func (m Metadata) Marshal() ([]byte, error) {
	return json.Marshal(m)
}

// Unmarshal converts JSON bytes or Metadata to Metadata. A nil input is
// normalized to an empty map; any other unsupported type is an error.
func (m *Metadata) Unmarshal(value interface{}) error {
	switch v := value.(type) {
	case nil:
		*m = Metadata{}
		return nil
	case Metadata:
		*m = v
		return nil
	case []byte:
		return json.Unmarshal(v, m)
	default:
		return helper.NewError("byte assertion", errors.New("type assertion to []byte failed"))
	}
}
package sql
import (
"database/sql"
_ "embed"
"fmt"
"log"
)
// SQL scripts embedded at build time; each defines one group of database
// functions (the *Functions lists below name what each script must create).

//go:embed init.sql
var initSQL string

//go:embed chunks.sql
var chunksSQL string

//go:embed documents.sql
var documentsSQL string

//go:embed edges.sql
var edgesSQL string

//go:embed entities.sql
var entitiesSQL string
// Function lists for verification
// ChunksFunctions names the SQL functions chunks.sql must define.
var ChunksFunctions = []string{
"init_chunks",
"insert_chunk",
"select_chunk",
"select_chunks_by_document",
"select_chunks_by_path_descendant",
"select_chunks_by_path_ancestor",
"select_chunks_by_similarity",
"select_chunks_by_similarity_with_context",
"delete_chunk",
"update_chunk_embedding",
}
// DocumentsFunctions names the SQL functions documents.sql must define.
var DocumentsFunctions = []string{
"init_documents",
"insert_document",
"select_document",
"select_all_documents",
"search_documents",
"update_document",
"delete_document",
}
// EdgesFunctions names the SQL functions edges.sql must define.
var EdgesFunctions = []string{
"init_edges",
"insert_edge",
"select_edge",
"select_edges_from_chunk",
"select_edges_to_chunk",
"select_edges_connected_to_chunk",
"select_edges_from_entity",
"select_edges_to_entity",
"delete_edge",
"update_edge_weight",
"traverse_bfs_from_chunk",
}
// EntitiesFunctions names the SQL functions entities.sql must define.
var EntitiesFunctions = []string{
"init_entities",
"insert_entity",
"select_entity",
"select_entity_by_name",
"search_entities",
"select_entities_by_type",
"delete_entity",
"update_entity_metadata",
"select_chunks_mentioning_entity",
}
// Init initializes the database extensions by executing the embedded
// init.sql script.
func Init(db *sql.DB) error {
	if _, err := db.Exec(initSQL); err != nil {
		return fmt.Errorf("error executing schema SQL: %w", err)
	}
	log.Println("Database extensions initialized successfully")
	return nil
}
// loadSQL is the shared loader behind the Load*Sql functions. Unless force
// is set, it first checks whether every function in functions already exists
// and returns early if so. It then executes script and verifies that all
// expected functions were actually created. name labels log/error messages.
func loadSQL(db *sql.DB, script string, functions []string, name string, force bool) error {
	if !force {
		exist, err := checkFunctions(db, functions)
		if err != nil {
			return fmt.Errorf("error checking existing %s functions: %w", name, err)
		}
		if exist {
			return nil
		}
	}
	if _, err := db.Exec(script); err != nil {
		return fmt.Errorf("error executing %s SQL: %w", name, err)
	}
	exist, err := checkFunctions(db, functions)
	if err != nil {
		return fmt.Errorf("error checking existing functions: %w", err)
	}
	if !exist {
		return fmt.Errorf("not all required SQL functions were created")
	}
	log.Printf("SQL %s functions loaded successfully", name)
	return nil
}

// LoadChunksSql loads chunk-related SQL functions.
func LoadChunksSql(db *sql.DB, force bool) error {
	return loadSQL(db, chunksSQL, ChunksFunctions, "chunks", force)
}

// LoadDocumentsSql loads document-related SQL functions.
func LoadDocumentsSql(db *sql.DB, force bool) error {
	return loadSQL(db, documentsSQL, DocumentsFunctions, "documents", force)
}

// LoadEdgesSql loads edge-related SQL functions.
func LoadEdgesSql(db *sql.DB, force bool) error {
	return loadSQL(db, edgesSQL, EdgesFunctions, "edges", force)
}

// LoadEntitiesSql loads entity-related SQL functions.
func LoadEntitiesSql(db *sql.DB, force bool) error {
	return loadSQL(db, entitiesSQL, EntitiesFunctions, "entities", force)
}
// LoadAllSql loads every SQL function group (chunks, documents, edges,
// entities), stopping at the first failure.
func LoadAllSql(db *sql.DB, force bool) error {
	loaders := []func(*sql.DB, bool) error{
		LoadChunksSql,
		LoadDocumentsSql,
		LoadEdgesSql,
		LoadEntitiesSql,
	}
	for _, load := range loaders {
		if err := load(db, force); err != nil {
			return err
		}
	}
	return nil
}
// checkFunctions verifies that all required functions exist in the database.
// It returns true when every function in sqlFunctions is present in pg_proc;
// an empty list vacuously reports true (the previous version returned the
// zero-value false for an empty list, forcing a needless reload).
func checkFunctions(db *sql.DB, sqlFunctions []string) (bool, error) {
	for _, f := range sqlFunctions {
		var exists bool
		err := db.QueryRow(
			`SELECT EXISTS(SELECT 1 FROM pg_proc WHERE proname = $1);`,
			f,
		).Scan(&exists)
		if err != nil {
			return false, fmt.Errorf("error checking existence of function %s: %w", f, err)
		}
		if !exists {
			log.Printf("Function %s does not exist", f)
			return false, nil
		}
	}
	return true, nil
}