package core
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"time"
"github.com/knights-analytics/hugot"
"github.com/knights-analytics/hugot/backends"
"github.com/knights-analytics/hugot/pipelines"
)
// Engine is the inference backend used by the HTTP handlers. HugotEngine is
// the model-backed implementation; MockEngine is a stand-in that needs no
// model files.
type Engine interface {
// Generate returns one completion for prompt. maxTokens limits the generated
// length; HugotEngine substitutes 256 when it is not positive.
Generate(ctx context.Context, prompt string, maxTokens int) (string, error)
// GenerateStream starts a completion and returns a channel of text fragments,
// a channel of asynchronous errors, and an error for failures that happen
// before streaming begins.
GenerateStream(ctx context.Context, prompt string, maxTokens int) (chan string, chan error, error)
// ExtractEmbeddings returns one embedding vector per element of input.
ExtractEmbeddings(ctx context.Context, input []string) ([][]float32, error)
// CountTokens reports the token count of text. In HugotEngine, isEmbedding
// selects the embedding model's tokenizer instead of the chat model's.
CountTokens(text string, isEmbedding bool) (int, error)
// Close releases resources held by the engine.
Close() error
}
// HugotEngine implements Engine on top of a hugot session. Either pipeline may
// be nil when the corresponding model name was not supplied to NewHugotEngine.
type HugotEngine struct {
// session is the hugot session both pipelines were created on; Close destroys it.
session *hugot.Session
// pipeline is the text generation pipeline used by Generate and GenerateStream.
pipeline *pipelines.TextGenerationPipeline
// embeddingPipeline is the feature extraction pipeline used by ExtractEmbeddings.
embeddingPipeline *pipelines.FeatureExtractionPipeline
}
// NewHugotEngine creates a hugot session and the pipelines for the requested
// models.
//
// modelFolder is created if it does not exist. For each non-empty model name
// the model is expected at filepath.Join(modelFolder, name); if that path does
// not exist the model is downloaded into modelFolder first. An empty
// speechModel or embeddingModel leaves the corresponding pipeline nil.
//
// If any step after session creation fails, the session is destroyed before
// the error is returned so that a failed construction does not leak it.
func NewHugotEngine(modelFolder string, speechModel string, embeddingModel string) (*HugotEngine, error) {
	if err := os.MkdirAll(modelFolder, 0750); err != nil {
		return nil, fmt.Errorf("failed to create model folder: %w", err)
	}

	ctx := context.Background()
	opts := hugot.NewDownloadOptions()

	session, err := hugot.NewORTSession(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to create hugot session: %w", err)
	}

	// Every early return below is a failure; release the session in that case.
	initialized := false
	defer func() {
		if initialized {
			return
		}
		if destroyErr := session.Destroy(); destroyErr != nil {
			slog.Warn("Failed to destroy hugot session after initialization error", "error", destroyErr)
		}
	}()

	// ensureLocal returns the on-disk path for a model, downloading it when
	// the path does not exist yet. kind only labels the error message.
	ensureLocal := func(name, kind string) (string, error) {
		localPath := filepath.Join(modelFolder, name)
		if _, statErr := os.Stat(localPath); os.IsNotExist(statErr) {
			slog.Info("Downloading model", "model", name, "folder", modelFolder)
			if _, dlErr := hugot.DownloadModel(ctx, name, modelFolder, opts); dlErr != nil {
				return "", fmt.Errorf("failed to download %s model: %w", kind, dlErr)
			}
		}
		return localPath, nil
	}

	var chatPipeline *pipelines.TextGenerationPipeline
	if speechModel != "" {
		localPath, err := ensureLocal(speechModel, "speech")
		if err != nil {
			return nil, err
		}
		config := hugot.TextGenerationConfig{
			ModelPath: localPath,
		}
		chatPipeline, err = hugot.NewPipeline(session, config)
		if err != nil {
			return nil, fmt.Errorf("failed to create text generation pipeline: %w", err)
		}
	}

	var embPipeline *pipelines.FeatureExtractionPipeline
	if embeddingModel != "" {
		localPath, err := ensureLocal(embeddingModel, "embedding")
		if err != nil {
			return nil, err
		}
		embConfig := hugot.FeatureExtractionConfig{
			ModelPath: localPath,
		}
		embPipeline, err = hugot.NewPipeline(session, embConfig)
		if err != nil {
			return nil, fmt.Errorf("failed to create feature extraction pipeline: %w", err)
		}
	}

	initialized = true
	return &HugotEngine{
		session:           session,
		pipeline:          chatPipeline,
		embeddingPipeline: embPipeline,
	}, nil
}
// Generate runs the text generation pipeline once, non-streaming, and returns
// the first response.
//
// A non-positive maxTokens is replaced by 256. An error is returned when the
// pipeline was never initialized, when the pipeline run fails, or when the run
// produces no responses.
//
// NOTE(review): this writes MaxLength and Streaming on the shared pipeline
// before running it; confirm callers do not invoke Generate and GenerateStream
// concurrently on the same engine.
func (e *HugotEngine) Generate(ctx context.Context, prompt string, maxTokens int) (string, error) {
	if e.pipeline == nil {
		return "", fmt.Errorf("text generation pipeline not initialized")
	}

	limit := maxTokens
	if limit <= 0 {
		limit = 256
	}
	e.pipeline.MaxLength = limit
	e.pipeline.Streaming = false

	out, err := e.pipeline.RunPipeline(ctx, []string{prompt})
	switch {
	case err != nil:
		return "", err
	case len(out.Responses) == 0:
		return "", fmt.Errorf("no output generated")
	}
	return out.Responses[0], nil
}
// GenerateStream runs the text generation pipeline in streaming mode.
//
// It returns a channel carrying the text of each streamed token, the
// pipeline's error stream, and an error for failures that occur before
// streaming starts (pipeline not initialized, or the pipeline run failing).
// A non-positive maxTokens is replaced by 256. The token channel is closed
// once the pipeline's token stream is exhausted.
//
// When ctx is cancelled the forwarding goroutine stops delivering tokens and
// drains the pipeline's stream instead, so neither it nor the producer stays
// blocked on a consumer that has gone away.
//
// NOTE(review): this writes MaxLength and Streaming on the shared pipeline
// before running it; confirm callers do not invoke Generate and GenerateStream
// concurrently on the same engine.
func (e *HugotEngine) GenerateStream(ctx context.Context, prompt string, maxTokens int) (chan string, chan error, error) {
	if e.pipeline == nil {
		return nil, nil, fmt.Errorf("text generation pipeline not initialized")
	}
	if maxTokens <= 0 {
		maxTokens = 256
	}
	e.pipeline.MaxLength = maxTokens
	e.pipeline.Streaming = true

	result, err := e.pipeline.RunPipeline(ctx, []string{prompt})
	if err != nil {
		return nil, nil, err
	}

	tokenChan := make(chan string)
	go func() {
		defer close(tokenChan)
		for token := range result.TokenStream {
			select {
			case tokenChan <- token.Token:
			case <-ctx.Done():
				// The consumer is gone: discard the remainder so the
				// producer can finish and close its stream.
				for range result.TokenStream {
				}
				return
			}
		}
	}()
	return tokenChan, result.ErrorStream, nil
}
// ExtractEmbeddings runs the feature extraction pipeline over input and
// returns the embeddings it produced. It fails when the embedding pipeline was
// never initialized or when the pipeline run returns an error.
func (e *HugotEngine) ExtractEmbeddings(ctx context.Context, input []string) ([][]float32, error) {
	extractor := e.embeddingPipeline
	if extractor == nil {
		return nil, fmt.Errorf("feature extraction pipeline not initialized")
	}

	out, runErr := extractor.RunPipeline(ctx, input)
	if runErr != nil {
		return nil, runErr
	}
	return out.Embeddings, nil
}
// CountTokens reports how many tokens text encodes to.
//
// isEmbedding selects the embedding pipeline's tokenizer; otherwise the text
// generation pipeline's tokenizer is used. When the selected pipeline, its
// model, or its Go tokenizer is unavailable, the count falls back to the number
// of whitespace-separated fields in text. The returned error is always nil.
func (e *HugotEngine) CountTokens(text string, isEmbedding bool) (int, error) {
	var tokenizer *backends.Tokenizer
	switch {
	case isEmbedding && e.embeddingPipeline != nil:
		if m := e.embeddingPipeline.GetModel(); m != nil {
			tokenizer = m.Tokenizer
		}
	case !isEmbedding && e.pipeline != nil:
		if m := e.pipeline.GetModel(); m != nil {
			tokenizer = m.Tokenizer
		}
	}

	usable := tokenizer != nil && tokenizer.GoTokenizer != nil && tokenizer.GoTokenizer.Tokenizer != nil
	if !usable {
		// No Go tokenizer to consult: approximate with a word count.
		return len(strings.Fields(text)), nil
	}
	return len(tokenizer.GoTokenizer.Tokenizer.Encode(text)), nil
}
// Close destroys the underlying session and clears the reference to it, so a
// second call is a no-op that returns nil. The error from destroying the
// session is returned.
func (e *HugotEngine) Close() error {
	if e.session == nil {
		return nil
	}
	destroyErr := e.session.Destroy()
	e.session = nil
	return destroyErr
}
// MockEngine is an Engine stand-in for tests and for running without models.
//
// It is not safe for concurrent use: Generate and GenerateStream consume
// Responses without synchronization.
type MockEngine struct {
	// Responses is a queue of canned completions. Each generation call returns
	// the first entry and removes it, except that the last remaining entry is
	// kept and returned by every later call. When empty, a response echoing
	// the prompt is produced instead.
	Responses []string
	// Err, when non-nil, is returned immediately by Generate, GenerateStream
	// and ExtractEmbeddings.
	Err error
	// StreamErr, when non-nil, is delivered on the error channel returned by
	// GenerateStream instead of any tokens.
	StreamErr error
}

// nextResponse returns the completion for the next generation call and
// advances the Responses queue as described on the Responses field.
func (m *MockEngine) nextResponse(prompt string) string {
	if len(m.Responses) == 0 {
		return "This is a mock response to: " + strings.TrimSpace(prompt)
	}
	resp := m.Responses[0]
	if len(m.Responses) > 1 {
		m.Responses = m.Responses[1:]
	}
	return resp
}

// CountTokens returns the number of whitespace-separated fields in text.
// isEmbedding is ignored and the error is always nil.
func (m *MockEngine) CountTokens(text string, isEmbedding bool) (int, error) {
	return len(strings.Fields(text)), nil
}

// ExtractEmbeddings returns Err if set; otherwise it returns one fixed
// three-element dummy vector per input string.
func (m *MockEngine) ExtractEmbeddings(ctx context.Context, input []string) ([][]float32, error) {
	if m.Err != nil {
		return nil, m.Err
	}
	res := make([][]float32, len(input))
	for i := range input {
		// Just a dummy embedding of length 3
		res[i] = []float32{0.1, 0.2, 0.3}
	}
	return res, nil
}

// Generate returns Err if set; otherwise it returns the next canned response,
// or an echo of the trimmed prompt when no responses are queued. ctx and
// maxTokens are ignored.
func (m *MockEngine) Generate(ctx context.Context, prompt string, maxTokens int) (string, error) {
	if m.Err != nil {
		return "", m.Err
	}
	return m.nextResponse(prompt), nil
}

// GenerateStream returns Err if set. Otherwise it selects the response before
// returning (so consecutive calls consume Responses in call order) and streams
// it from a goroutine: the response is split on single spaces and each word is
// sent as a token, with a " " token between words and a 10ms pause after each
// word. If StreamErr is set it is sent on the error channel and no tokens are
// produced.
//
// Both channels are closed when streaming ends. The goroutine also stops as
// soon as ctx is cancelled, so it does not stay blocked on a reader that has
// gone away. maxTokens is ignored.
func (m *MockEngine) GenerateStream(ctx context.Context, prompt string, maxTokens int) (chan string, chan error, error) {
	if m.Err != nil {
		return nil, nil, m.Err
	}

	// Snapshot engine state synchronously; the goroutine must not touch m.
	response := m.nextResponse(prompt)
	streamErr := m.StreamErr

	tokenChan := make(chan string)
	errChan := make(chan error)

	go func() {
		defer close(tokenChan)
		defer close(errChan)

		if streamErr != nil {
			select {
			case errChan <- streamErr:
			case <-ctx.Done():
			}
			return
		}

		// send delivers one token and reports false when ctx ended first.
		send := func(tok string) bool {
			select {
			case tokenChan <- tok:
				return true
			case <-ctx.Done():
				return false
			}
		}

		for i, word := range strings.Split(response, " ") {
			if i > 0 && !send(" ") {
				return
			}
			if !send(word) {
				return
			}
			// Simulate generation delay, but stay cancellable.
			select {
			case <-time.After(10 * time.Millisecond):
			case <-ctx.Done():
				return
			}
		}
	}()

	return tokenChan, errChan, nil
}

// Close does nothing and returns nil.
func (m *MockEngine) Close() error {
	return nil
}
package core
import (
"encoding/json"
"github.com/siherrmann/talker/model"
"github.com/siherrmann/validator"
)
// ValidateJSON decodes output as a model.TargetJSONOutput and then runs v
// against the decoded value. It returns the decoding error if output is not
// valid JSON for that type, otherwise whatever v.Validate returns.
func ValidateJSON(output string, v *validator.Validator) error {
	target := new(model.TargetJSONOutput)
	if decodeErr := json.Unmarshal([]byte(output), target); decodeErr != nil {
		return decodeErr
	}
	return v.Validate(target)
}
package handler
import (
"encoding/json"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v5"
"github.com/siherrmann/talker/core"
"github.com/siherrmann/talker/metrics"
"github.com/siherrmann/talker/model"
"github.com/siherrmann/validator"
)
// ChatHandler serves the chat completions endpoint.
type ChatHandler struct {
// Engine generates completions and counts tokens for usage reporting.
Engine core.Engine
// validator checks generated output when a JSON object response is requested.
validator *validator.Validator
}
// NewChatHandler builds a ChatHandler that uses engine for generation and a
// freshly created validator for JSON output checks.
func NewChatHandler(engine core.Engine) *ChatHandler {
	h := new(ChatHandler)
	h.Engine = engine
	h.validator = validator.NewValidator()
	return h
}
// ChatCompletions handles a chat completion request.
//
// It binds the request body, rejects an unparsable payload or an empty
// message list with 400, flattens the messages into a "role: content" prompt
// (one line per message) ending with "assistant: ", and then dispatches to the
// streaming or non-streaming path depending on req.Stream.
func (h *ChatHandler) ChatCompletions(c *echo.Context) error {
	var req model.ChatCompletionRequest
	if bindErr := c.Bind(&req); bindErr != nil {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request payload"})
	}
	if len(req.Messages) == 0 {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "Messages array cannot be empty"})
	}

	// Simple prompt template: every message on its own line, then the cue
	// for the assistant's turn.
	var sb strings.Builder
	for i := range req.Messages {
		msg := req.Messages[i]
		sb.WriteString(msg.Role + ": " + msg.Content + "\n")
	}
	sb.WriteString("assistant: ")
	prompt := sb.String()

	if !req.Stream {
		return h.handleNonStreaming(c, req, prompt)
	}
	return h.handleStreaming(c, req, prompt)
}
// handleStreaming answers a chat completion request as a server-sent event
// stream.
//
// It starts a streaming generation, sets the event-stream response headers,
// and then forwards every token as a "chat.completion.chunk" event. All chunks
// of one response share the same id and creation timestamp. When the token
// channel is closed it emits a final chunk with finish reason "stop" followed
// by "data: [DONE]". Each forwarded token adds 1 to the tokens-consumed metric.
//
// The loop ends early, returning nil, when the request context is done, and
// returns the error when the engine reports one on its error channel or a
// write to the client fails. A failure to start generation is reported as a
// 500 JSON response.
func (h *ChatHandler) handleStreaming(c *echo.Context, req model.ChatCompletionRequest, prompt string) error {
	ctx := c.Request().Context()

	tokenChan, errChan, err := h.Engine.GenerateStream(ctx, prompt, req.MaxTokens)
	if err != nil {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to generate completion stream: " + err.Error()})
	}

	c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
	c.Response().Header().Set("Cache-Control", "no-cache")
	c.Response().Header().Set("Connection", "keep-alive")

	id := "chatcmpl-" + uuid.New().String()
	created := time.Now().Unix()

	// The metric labels depend only on the request, so compute them once.
	labels := metrics.ExtractLabels(c, req.Model)

	// newChunk builds one stream event carrying delta and an optional finish reason.
	newChunk := func(delta model.ChunkDelta, finishReason *string) model.ChatCompletionChunkResponse {
		return model.ChatCompletionChunkResponse{
			ID:      id,
			Object:  "chat.completion.chunk",
			Created: created,
			Model:   req.Model,
			Choices: []model.ChunkChoice{
				{
					Index:        0,
					Delta:        delta,
					FinishReason: finishReason,
				},
			},
		}
	}

	for {
		select {
		case <-ctx.Done():
			return nil
		case streamErr, ok := <-errChan:
			if !ok {
				// A closed channel is always ready to receive; drop it from
				// the select so the loop does not spin while tokens are
				// still pending.
				errChan = nil
				continue
			}
			if streamErr != nil {
				return streamErr
			}
		case token, ok := <-tokenChan:
			if !ok {
				stopReason := "stop"
				if err := streamChunk(c, newChunk(model.ChunkDelta{}, &stopReason)); err != nil {
					return err
				}
				if _, err := c.Response().Write([]byte("data: [DONE]\n\n")); err != nil {
					return err
				}
				if flusher, ok := c.Response().(http.Flusher); ok {
					flusher.Flush()
				}
				return nil
			}

			// Record token consumption for streaming
			metrics.TokensConsumedTotal.WithLabelValues(labels...).Add(1)

			if err := streamChunk(c, newChunk(model.ChunkDelta{Content: token}, nil)); err != nil {
				return err
			}
		}
	}
}
// handleNonStreaming answers a chat completion request with a single JSON
// response.
//
// For a "json_object" response format the completion is validated with
// core.ValidateJSON and regenerated up to three times in total; after each
// invalid attempt the rejected output and the validation error are appended to
// the prompt. If the last attempt is still invalid, or if generation itself
// fails, a 500 JSON error is returned. Other requests are generated once.
//
// Usage is computed from the original prompt and the final output and is also
// added to the tokens-consumed metric.
func (h *ChatHandler) handleNonStreaming(c *echo.Context, req model.ChatCompletionRequest, prompt string) error {
	const maxRetries = 3

	ctx := c.Request().Context()
	wantsJSON := req.ResponseFormat != nil && req.ResponseFormat.Type == "json_object"

	var output string
	if !wantsJSON {
		text, genErr := h.Engine.Generate(ctx, prompt, req.MaxTokens)
		if genErr != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to generate completion: " + genErr.Error()})
		}
		output = text
	} else {
		attemptPrompt := prompt
		for attempt := 1; ; attempt++ {
			text, genErr := h.Engine.Generate(ctx, attemptPrompt, req.MaxTokens)
			if genErr != nil {
				return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to generate completion: " + genErr.Error()})
			}
			output = text

			validationErr := core.ValidateJSON(output, h.validator)
			if validationErr == nil {
				break
			}
			if attempt == maxRetries {
				return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to generate valid JSON after multiple attempts: " + validationErr.Error()})
			}
			// Feed the rejected output and the reason back to the model.
			attemptPrompt += output + "\nThe JSON you provided was invalid. Error: " + validationErr.Error() + ". Please try again and provide ONLY valid JSON.\nassistant: "
		}
	}

	// Counting errors are deliberately ignored; usage is best-effort.
	promptTokens, _ := h.Engine.CountTokens(prompt, false)
	completionTokens, _ := h.Engine.CountTokens(output, false)
	totalTokens := promptTokens + completionTokens

	// Record total tokens for non-streaming
	labels := metrics.ExtractLabels(c, req.Model)
	metrics.TokensConsumedTotal.WithLabelValues(labels...).Add(float64(totalTokens))

	choice := model.Choice{
		Index: 0,
		Message: model.ChatMessage{
			Role:    "assistant",
			Content: output,
		},
		FinishReason: "stop",
	}
	usage := model.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      totalTokens,
	}
	return c.JSON(http.StatusOK, model.ChatCompletionResponse{
		ID:      "chatcmpl-" + uuid.New().String(),
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   req.Model,
		Choices: []model.Choice{choice},
		Usage:   usage,
	})
}
// streamChunk writes chunk to the response as one server-sent event
// ("data: <json>" followed by a blank line) and flushes the response when the
// writer supports flushing. It returns the marshalling or write error, if any.
func streamChunk(c *echo.Context, chunk model.ChatCompletionChunkResponse) error {
	encoded, marshalErr := json.Marshal(chunk)
	if marshalErr != nil {
		return marshalErr
	}

	event := make([]byte, 0, len("data: ")+len(encoded)+len("\n\n"))
	event = append(event, "data: "...)
	event = append(event, encoded...)
	event = append(event, "\n\n"...)

	w := c.Response()
	if _, writeErr := w.Write(event); writeErr != nil {
		return writeErr
	}
	if f, canFlush := w.(http.Flusher); canFlush {
		f.Flush()
	}
	return nil
}
package handler
import (
"net/http"
"github.com/labstack/echo/v5"
"github.com/siherrmann/talker/core"
"github.com/siherrmann/talker/metrics"
"github.com/siherrmann/talker/model"
)
// EmbeddingsHandler serves the embeddings endpoint.
type EmbeddingsHandler struct {
// Engine computes the embeddings and counts tokens for usage reporting.
Engine core.Engine
}
// NewEmbeddingsHandler builds an EmbeddingsHandler that uses engine.
func NewEmbeddingsHandler(engine core.Engine) *EmbeddingsHandler {
	h := new(EmbeddingsHandler)
	h.Engine = engine
	return h
}
// Embeddings handles an embeddings request.
//
// It binds the request body, rejects an unparsable payload or an empty input
// with 400, asks the engine for one embedding per input and reports an engine
// failure as 500. The response lists the embeddings in the order the engine
// returned them. Usage is the sum of the engine's token counts for each input
// and is also added to the tokens-consumed metric.
func (h *EmbeddingsHandler) Embeddings(c *echo.Context) error {
	var req model.EmbeddingRequest
	if bindErr := c.Bind(&req); bindErr != nil {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request payload: " + bindErr.Error()})
	}
	if len(req.Input) == 0 {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "Input cannot be empty"})
	}

	vectors, extractErr := h.Engine.ExtractEmbeddings(c.Request().Context(), req.Input)
	if extractErr != nil {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to extract embeddings: " + extractErr.Error()})
	}

	items := make([]model.Embedding, 0, len(vectors))
	for idx := range vectors {
		items = append(items, model.Embedding{
			Object:    "embedding",
			Embedding: vectors[idx],
			Index:     idx,
		})
	}

	// Usage comes from the engine's token counter; counting errors are
	// deliberately ignored because usage is best-effort.
	var promptTokens int
	for i := range req.Input {
		n, _ := h.Engine.CountTokens(req.Input[i], true)
		promptTokens += n
	}

	labels := metrics.ExtractLabels(c, req.Model)
	metrics.TokensConsumedTotal.WithLabelValues(labels...).Add(float64(promptTokens))

	usage := model.Usage{
		PromptTokens:     promptTokens,
		TotalTokens:      promptTokens,
		CompletionTokens: 0,
	}
	return c.JSON(http.StatusOK, model.EmbeddingResponse{
		Object: "list",
		Data:   items,
		Model:  req.Model,
		Usage:  usage,
	})
}
package main
import (
"log/slog"
"os"
"net/http"
"time"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/siherrmann/talker/core"
"github.com/siherrmann/talker/handler"
"github.com/siherrmann/talker/metrics"
)
// run wires up the service and blocks serving the API.
//
// Configuration comes from the environment:
//   - MODEL_FOLDER, CHAT_MODEL, EMBEDDING_MODEL select the Hugot engine; when
//     MODEL_FOLDER is empty or both model names are empty, MockEngine is used.
//   - PORT is the API port (default "8080").
//   - METRICS_PORT, when set, starts a separate Prometheus metrics server.
//
// It returns the engine initialization error or the error from the API
// server. The engine is closed when run returns; a failure to close it is
// logged.
func run() error {
	// 1. Initialize the Engine
	modelFolder := os.Getenv("MODEL_FOLDER")
	speechModel := os.Getenv("CHAT_MODEL")
	embeddingModel := os.Getenv("EMBEDDING_MODEL")

	var engine core.Engine
	var err error
	if modelFolder == "" || (speechModel == "" && embeddingModel == "") {
		slog.Warn("MODEL_FOLDER or model names not fully set, using MockEngine for testing.")
		engine = &core.MockEngine{}
	} else {
		slog.Info("Initializing Hugot engine", "speech_model", speechModel, "embedding_model", embeddingModel, "folder", modelFolder)
		engine, err = core.NewHugotEngine(modelFolder, speechModel, embeddingModel)
		if err != nil {
			return err
		}
	}
	defer func() {
		if closeErr := engine.Close(); closeErr != nil {
			slog.Error("Failed to close engine", "error", closeErr)
		}
	}()

	// 2. Initialize Handlers
	chatHandler := handler.NewChatHandler(engine)
	embeddingsHandler := handler.NewEmbeddingsHandler(engine)

	// 3. Setup Echo Server
	e := echo.New()

	// Middleware
	e.Use(middleware.Recover())
	e.Use(metrics.PrometheusMiddleware())

	// Register Routes
	e.POST("/v1/chat/completions", chatHandler.ChatCompletions)
	e.POST("/v1/embeddings", embeddingsHandler.Embeddings)

	// 4. Start Server
	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}

	metricsPort := os.Getenv("METRICS_PORT")
	if metricsPort != "" {
		// Use a dedicated mux so the metrics port serves /metrics only and
		// nothing that other code registers on http.DefaultServeMux.
		metricsMux := http.NewServeMux()
		metricsMux.Handle("/metrics", promhttp.Handler())
		server := &http.Server{
			Addr:              ":" + metricsPort,
			Handler:           metricsMux,
			ReadHeaderTimeout: 5 * time.Second,
			ReadTimeout:       10 * time.Second,
			WriteTimeout:      10 * time.Second,
		}
		go func() {
			slog.Info("Starting Prometheus Metrics Server", "port", metricsPort)
			if err := server.ListenAndServe(); err != nil {
				slog.Error("Metrics server failed", "error", err)
			}
		}()
	} else {
		slog.Info("METRICS_PORT not set, skipping Prometheus metrics server")
	}

	slog.Info("Starting Talker API", "port", port)
	return e.Start(":" + port)
}
// main runs the service and exits with status 1 after logging the error when
// run fails.
func main() {
	runErr := run()
	if runErr == nil {
		return
	}
	slog.Error("Server failed", "error", runErr)
	os.Exit(1)
}
package metrics
import (
"strconv"
"time"
"github.com/labstack/echo/v5"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// Prometheus collectors for the API. They are created through promauto when
// the package is initialized.
var (
// labelNames are the label dimensions of TokensConsumedTotal; ExtractLabels
// returns values in this same order.
labelNames = []string{"org_id", "project_id", "user_id", "model_name"}
// RequestsTotal counts HTTP requests by method, route path and status code.
RequestsTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "talker_requests_total",
Help: "Total number of HTTP requests made to the API",
},
[]string{"method", "path", "status"},
)
// TokensConsumedTotal counts tokens attributed to an org, project, user and model.
TokensConsumedTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "talker_tokens_consumed_total",
Help: "Total number of tokens consumed by the Talker API",
},
labelNames,
)
// RequestDuration observes request latency in seconds by method, route path
// and status code, using the default Prometheus buckets.
RequestDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "talker_request_duration_seconds",
Help: "Histogram of request latencies",
Buckets: prometheus.DefBuckets,
},
[]string{"method", "path", "status"},
)
)
// PrometheusMiddleware returns an echo middleware that, after the wrapped
// handler has run, increments RequestsTotal and records the elapsed time in
// RequestDuration, both labelled with the request method, the route path and
// the resolved response status. The handler's error is passed through.
func PrometheusMiddleware() echo.MiddlewareFunc {
	instrument := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			begin := time.Now()
			handlerErr := next(c)

			_, code := echo.ResolveResponseStatus(c.Response(), handlerErr)
			elapsed := time.Since(begin).Seconds()

			labels := []string{c.Request().Method, c.Path(), strconv.Itoa(code)}
			RequestsTotal.WithLabelValues(labels...).Inc()
			RequestDuration.WithLabelValues(labels...).Observe(elapsed)

			return handlerErr
		}
	}
	return instrument
}
// ExtractLabels builds the label values for TokensConsumedTotal in the order
// org ID, project ID, user ID, model name. The IDs are read from the
// X-Org-Id, X-Project-Id and X-User-Id request headers. Any value that is
// empty, including modelName, is replaced by "unknown".
func ExtractLabels(c *echo.Context, modelName string) []string {
	const fallback = "unknown"

	orFallback := func(v string) string {
		if v == "" {
			return fallback
		}
		return v
	}

	hdr := c.Request().Header
	return []string{
		orFallback(hdr.Get("X-Org-Id")),
		orFallback(hdr.Get("X-Project-Id")),
		orFallback(hdr.Get("X-User-Id")),
		orFallback(modelName),
	}
}
package model
import (
"encoding/json"
"fmt"
)
// EmbeddingRequest is the body of an embeddings request.
type EmbeddingRequest struct {
	// Model is the requested model name.
	Model string `json:"model"`
	// Input holds the texts to embed. On the wire "input" may be a single
	// string or an array of strings; UnmarshalJSON normalizes both to a slice.
	Input []string `json:"-"` // Handled by custom unmarshaler
}

// UnmarshalJSON decodes an embeddings request, accepting "input" either as a
// single string (stored as a one-element slice) or as an array of strings.
//
// It returns an error when data is not valid JSON for the request, when
// "input" is missing or null, or when "input" is neither a string nor an
// array of strings.
func (r *EmbeddingRequest) UnmarshalJSON(data []byte) error {
	// Alias has the same fields but no methods, which prevents this method
	// from being invoked recursively by the decoder.
	type Alias EmbeddingRequest
	aux := &struct {
		Input json.RawMessage `json:"input"`
		*Alias
	}{
		Alias: (*Alias)(r),
	}
	// Pass aux itself, not &aux: decoding a JSON null through a pointer to
	// the pointer would set aux to nil and crash on the access below.
	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}

	// A JSON null decodes into a string without error, so it has to be
	// rejected explicitly or it would be accepted as one empty input.
	if len(aux.Input) == 0 || string(aux.Input) == "null" {
		return fmt.Errorf("input is required")
	}

	// Try unmarshaling as string
	var singleString string
	if err := json.Unmarshal(aux.Input, &singleString); err == nil {
		r.Input = []string{singleString}
		return nil
	}

	// Try unmarshaling as []string
	var arrayString []string
	if err := json.Unmarshal(aux.Input, &arrayString); err == nil {
		r.Input = arrayString
		return nil
	}

	return fmt.Errorf("input must be a string or an array of strings")
}
// Embedding is one entry of an embeddings response.
type Embedding struct {
// Object is the entry's type tag.
Object string `json:"object"`
// Embedding is the embedding vector.
Embedding []float32 `json:"embedding"`
// Index is the entry's position in the response list.
Index int `json:"index"`
}
// EmbeddingResponse is the body of an embeddings response.
type EmbeddingResponse struct {
// Object is the response's type tag.
Object string `json:"object"`
// Data holds one Embedding per returned vector.
Data []Embedding `json:"data"`
// Model is the model name associated with the response.
Model string `json:"model"`
// Usage reports the token counts for the request.
Usage Usage `json:"usage"`
}