// Copyright 2025 herosizy
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
package azopenai
import (
"context"
"errors"
"fmt"
"os"
"sync"
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
)
// Provider identity constants.
//
// azureOpenAIProvider is the registry key under which every model and
// embedder of this plugin is defined; labelPrefix is prepended to
// human-readable model labels. Provider and LabelPrefix re-export the
// internal values for callers outside this package.
const (
	azureOpenAIProvider = "azureopenai"
	labelPrefix         = "Azure OpenAI"
	Provider            = azureOpenAIProvider
	LabelPrefix         = labelPrefix
)
// AzureOpenAI is a Genkit plugin for interacting with the Azure OpenAI service.
//
// Credentials fall back to the AZURE_OPEN_AI_API_KEY and
// AZURE_OPEN_AI_ENDPOINT environment variables when the corresponding
// fields are empty. Init must be called before any method that needs
// the underlying client. The struct contains a sync.Mutex and must not
// be copied after first use; pass it by pointer.
type AzureOpenAI struct {
	APIKey   string // API key to access the service. If empty, the value of the environment variable AZURE_OPEN_AI_API_KEY will be consulted.
	Endpoint string // Azure OpenAI endpoint. If empty, the value of the environment variable AZURE_OPEN_AI_ENDPOINT will be consulted.

	client  *azopenai.Client // Client for the Azure OpenAI service; set by Init.
	mu      sync.Mutex       // Mutex guarding client and initted.
	initted bool             // Whether the plugin has been initialized.
}
// Name returns the registry name of the plugin ("azureopenai").
func (az *AzureOpenAI) Name() string {
	return Provider
}
// Init initializes the Azure OpenAI plugin and all known models.
// After calling Init, you may call [DefineModel] to create
// and register any additional generative models.
//
// Credentials come from the struct fields, falling back to the
// AZURE_OPEN_AI_API_KEY and AZURE_OPEN_AI_ENDPOINT environment
// variables. Calling Init a second time returns an error.
func (az *AzureOpenAI) Init(ctx context.Context, g *genkit.Genkit) (err error) {
	if az == nil {
		// NOTE(review): this rebinds only the local copy of the receiver;
		// the caller's nil pointer is unchanged, so any state set below is
		// discarded. It does prevent a nil-pointer panic on az.mu.Lock.
		az = &AzureOpenAI{}
	}
	az.mu.Lock()
	defer az.mu.Unlock()
	if az.initted {
		return errors.New("plugin already initialized")
	}
	// Prefix every error returned below for easier attribution.
	defer func() {
		if err != nil {
			err = fmt.Errorf("AzureOpenAI.Init: %w", err)
		}
	}()
	// Resolve the API key: struct field first, then the environment.
	apiKey := az.APIKey
	if apiKey == "" {
		apiKey = os.Getenv("AZURE_OPEN_AI_API_KEY")
		if apiKey == "" {
			return fmt.Errorf("Azure OpenAI requires setting AZURE_OPEN_AI_API_KEY in the environment")
		}
	}
	// Resolve the endpoint the same way.
	endpoint := az.Endpoint
	if endpoint == "" {
		endpoint = os.Getenv("AZURE_OPEN_AI_ENDPOINT")
		if endpoint == "" {
			return fmt.Errorf("Azure OpenAI requires setting AZURE_OPEN_AI_ENDPOINT in the environment")
		}
	}
	// Build the SDK client with key-credential auth; telemetry stays enabled.
	client, err := azopenai.NewClientWithKeyCredential(endpoint, azcore.NewKeyCredential(apiKey), &azopenai.ClientOptions{
		ClientOptions: azcore.ClientOptions{
			Telemetry: policy.TelemetryOptions{
				Disabled: false,
			},
		},
	})
	if err != nil {
		return err
	}
	az.client = client
	az.initted = true
	models, err := listModels()
	if err != nil {
		return err
	}
	// Register all supported models
	for name, modelInfo := range models {
		defineModel(g, az.client, name, modelInfo)
	}
	// Register embedding models
	embeddingModels, err := listEmbedders()
	if err != nil {
		return err
	}
	for _, name := range embeddingModels {
		defineEmbedder(g, az.client, name)
	}
	return nil
}
// DefineModel defines an unknown model with the given name.
// The second argument describes the capability of the model.
// Use [IsDefinedModel] to determine if a model is already defined.
// After [Init] is called, only the known models are defined.
//
// When info is nil the capability metadata of a known supported model
// is used; an unknown name with nil info is an error.
func (az *AzureOpenAI) DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo) (ai.Model, error) {
	az.mu.Lock()
	defer az.mu.Unlock()
	if !az.initted {
		return nil, errors.New("AzureOpenAI plugin not initialized")
	}
	// Caller-supplied metadata wins outright.
	if info != nil {
		return defineModel(g, az.client, name, *info), nil
	}
	known, err := listModels()
	if err != nil {
		return nil, err
	}
	mi, ok := known[name]
	if !ok {
		return nil, fmt.Errorf("AzureOpenAI.DefineModel: called with unknown model %q and nil ModelInfo", name)
	}
	return defineModel(g, az.client, name, mi), nil
}
// Model returns a reference to the named model registered by this plugin,
// or nil if no such model is defined.
func Model(g *genkit.Genkit, name string) ai.Model {
	return genkit.LookupModel(g, Provider, name)
}
// ModelRef creates a model reference (with optional config) that can be
// used in flows without requiring the model to be defined yet.
func ModelRef(name string, config *OpenAIConfig) ai.ModelRef {
	qualified := Provider + "/" + name
	return ai.NewModelRef(qualified, config)
}
// DefineModel allows users to define a custom model configuration.
//
// If info is nil, the capability metadata of a known supported model is
// used; defining an unknown model with nil info panics with a descriptive
// message. (Previously a nil info caused a bare nil-pointer dereference.)
// The model is registered without a client, so generation calls made
// through it will fail until a client-backed model is defined.
func DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo) ai.Model {
	if info != nil {
		return defineModel(g, nil, name, *info)
	}
	// Fall back to the built-in capability table for known models.
	if models, err := listModels(); err == nil {
		if mi, ok := models[name]; ok {
			return defineModel(g, nil, name, mi)
		}
	}
	panic(fmt.Sprintf("azopenai.DefineModel: unknown model %q with nil ModelInfo", name))
}
// IsDefinedModel checks if a model is already defined.
func IsDefinedModel(name string) bool {
model := genkit.LookupModel(nil, azureOpenAIProvider, name)
return model != nil
}
// Embedder returns an embedder with the given name.
// It panics if the embedder was never registered (e.g. the plugin was not
// initialized or the name is unsupported).
func Embedder(g *genkit.Genkit, name string) ai.Embedder {
	e := genkit.LookupEmbedder(g, Provider, name)
	if e != nil {
		return e
	}
	panic(fmt.Sprintf("Embedder %q was not found. Make sure you configured the Azure OpenAI plugin and that the embedder is supported.", name))
}
// IsDefinedEmbedder reports whether the named embedding model is supported
// by this plugin.
func IsDefinedEmbedder(name string) bool {
	supported, err := listEmbedders()
	if err != nil {
		return false
	}
	for i := range supported {
		if supported[i] == name {
			return true
		}
	}
	return false
}
// DefineEmbedder defines an embedder with the given name, provided the
// name is one of the supported embedding models.
func (a *AzureOpenAI) DefineEmbedder(g *genkit.Genkit, name string) (ai.Embedder, error) {
	if IsDefinedEmbedder(name) {
		return defineEmbedder(g, a.client, name), nil
	}
	return nil, fmt.Errorf("embedder %s is not supported", name)
}
// IsDefinedEmbedder reports whether the named Embedder is currently
// registered with Genkit under this plugin's provider name.
func (a *AzureOpenAI) IsDefinedEmbedder(g *genkit.Genkit, name string) bool {
	embedder := genkit.LookupEmbedder(g, Provider, name)
	return embedder != nil
}
// Copyright 2025 herosizy
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
package azopenai
import (
"github.com/firebase/genkit/go/ai"
)
// Model name constants. The lower-case identifiers are the canonical
// Azure OpenAI model names; the exported aliases at the bottom re-export
// them for callers outside this package.
const (
	// Reasoning models:
	// o-series models that excel at complex, multi-step tasks.
	o4Mini = "o4-mini"
	o3     = "o3"
	o3Mini = "o3-mini"
	o1     = "o1"
	o1Mini = "o1-mini"
	o1Pro  = "o1-pro"

	// Flagship chat models:
	// versatile, high-intelligence general-purpose models.
	gpt41          = "gpt-4.1"
	gpt41Mini      = "gpt-4.1-mini"
	gpt41Nano      = "gpt-4.1-nano"
	gpt4o          = "gpt-4o"
	gpt4oMini      = "gpt-4o-mini"
	gpt4oAudio     = "gpt-4o-audio-preview"
	gpt4oMiniAudio = "gpt-4o-mini-audio-preview"
	chatgpt4o      = "chatgpt-4o-latest"

	// Image generation models:
	// models that can generate and edit images from a natural language prompt.
	gptImage1 = "gpt-image-1"
	dalle3    = "dall-e-3"
	dalle2    = "dall-e-2"

	// Embeddings:
	// models that convert text into vector representations.
	textEmbedding3Large = "text-embedding-3-large"
	textEmbedding3Small = "text-embedding-3-small"

	// Older GPT models:
	// supported older versions of general purpose and chat models.
	gpt35Turbo         = "gpt-3.5-turbo"
	gpt35TurboInstruct = "gpt-3.5-turbo-instruct"
	gpt4               = "gpt-4"
	gpt4Turbo          = "gpt-4-turbo"
	gpt4TurboPreview   = "gpt-4-turbo-preview"

	// Exported model constants for external use.
	// NOTE: not every exported name appears in azureOpenAIModels below;
	// only the listed subset is registered by Init.
	Gpt35Turbo          = gpt35Turbo
	Gpt35TurboInstruct  = gpt35TurboInstruct
	Gpt4                = gpt4
	Gpt4Turbo           = gpt4Turbo
	Gpt4TurboPreview    = gpt4TurboPreview
	Gpt41               = gpt41
	Gpt41Mini           = gpt41Mini
	Gpt41Nano           = gpt41Nano
	Gpt4o               = gpt4o
	Gpt4oMini           = gpt4oMini
	Gpt4oAudio          = gpt4oAudio
	Gpt4oMiniAudio      = gpt4oMiniAudio
	O4Mini              = o4Mini
	O3                  = o3
	O3Mini              = o3Mini
	O1                  = o1
	O1Mini              = o1Mini
	O1Pro               = o1Pro
	Chatgpt4o           = chatgpt4o
	GptImage1           = gptImage1
	Dalle3              = dalle3
	Dalle2              = dalle2
	TextEmbedding3Large = textEmbedding3Large
	TextEmbedding3Small = textEmbedding3Small
)
var (
	// azureOpenAIModels is the list of model names registered by Init.
	// Each entry must have a matching record in supportedAzureOpenAIModels;
	// names without one are silently skipped by listModels.
	azureOpenAIModels = []string{
		gpt4,
		gpt4Turbo,
		gpt4TurboPreview,
		gpt4o,
		gpt4oMini,
		gpt35Turbo,
		gpt35TurboInstruct,
		textEmbedding3Large,
		textEmbedding3Small,
		gpt41,
		gpt41Mini,
		o4Mini,
	}

	// TextModel describes the capabilities of text-only chat models.
	// NOTE(review): exported and mutable; callers could alter shared
	// capability metadata at runtime.
	TextModel = ai.ModelSupports{
		Multiturn:  true,
		Tools:      true,
		ToolChoice: true,
		SystemRole: true,
		Media:      false,
	}

	// MultimodalModel describes the capabilities of models that also
	// accept media input.
	MultimodalModel = ai.ModelSupports{
		Multiturn:  true,
		Tools:      true,
		ToolChoice: true,
		SystemRole: true,
		Media:      true,
	}

	// supportedAzureOpenAIModels maps model names to their capability
	// metadata (label, known versions, supports, release stage).
	supportedAzureOpenAIModels = map[string]ai.ModelInfo{
		gpt4: {
			Label: "GPT-4",
			Versions: []string{
				"gpt-4",
			},
			Supports: &TextModel,
			Stage:    ai.ModelStageStable,
		},
		gpt4Turbo: {
			Label: "GPT-4 Turbo",
			Versions: []string{
				"gpt-4-turbo",
			},
			Supports: &MultimodalModel,
			Stage:    ai.ModelStageStable,
		},
		gpt4TurboPreview: {
			Label: "GPT-4 Turbo Preview",
			Versions: []string{
				"gpt-4-turbo-preview",
			},
			Supports: &MultimodalModel,
			Stage:    ai.ModelStageUnstable,
		},
		gpt4o: {
			Label: "GPT-4o",
			Versions: []string{
				"gpt-4o-2024-05-13",
				"gpt-4o-2024-08-06",
			},
			Supports: &MultimodalModel,
			Stage:    ai.ModelStageStable,
		},
		gpt4oMini: {
			Label: "GPT-4o Mini",
			Versions: []string{
				"gpt-4o-mini-2024-07-18",
			},
			Supports: &MultimodalModel,
			Stage:    ai.ModelStageStable,
		},
		gpt35Turbo: {
			Label: "GPT-3.5 Turbo",
			Versions: []string{
				"gpt-3.5-turbo-0613",
				"gpt-3.5-turbo-16k",
				"gpt-3.5-turbo-16k-0613",
				"gpt-3.5-turbo-1106",
				"gpt-3.5-turbo-0125",
			},
			Supports: &TextModel,
			Stage:    ai.ModelStageStable,
		},
		gpt35TurboInstruct: {
			Label: "GPT-3.5 Turbo Instruct",
			Versions: []string{
				"gpt-3.5-turbo-instruct-0914",
			},
			Supports: &TextModel,
			Stage:    ai.ModelStageStable,
		},
		// Embedding models carry an all-false ModelSupports: they are not
		// chat models and are also registered separately as embedders.
		textEmbedding3Large: {
			Label: "Text Embedding 3 Large",
			Versions: []string{
				"text-embedding-3-large",
			},
			Supports: &ai.ModelSupports{
				Multiturn:  false,
				Tools:      false,
				ToolChoice: false,
				SystemRole: false,
				Media:      false,
			},
			Stage: ai.ModelStageStable,
		},
		textEmbedding3Small: {
			Label: "Text Embedding 3 Small",
			Versions: []string{
				"text-embedding-3-small",
			},
			Supports: &ai.ModelSupports{
				Multiturn:  false,
				Tools:      false,
				ToolChoice: false,
				SystemRole: false,
				Media:      false,
			},
			Stage: ai.ModelStageStable,
		},
		gpt41: {
			Label: "GPT-4.1",
			Versions: []string{
				"gpt-4.1",
			},
			Supports: &MultimodalModel,
			Stage:    ai.ModelStageUnstable,
		},
		gpt41Mini: {
			Label: "GPT-4.1 Mini",
			Versions: []string{
				"gpt-4.1-mini",
			},
			Supports: &MultimodalModel,
			Stage:    ai.ModelStageUnstable,
		},
		o4Mini: {
			Label: "O4 Mini",
			Versions: []string{
				"o4-mini",
			},
			Supports: &MultimodalModel,
			Stage:    ai.ModelStageUnstable,
		},
	}
)
// listModels returns the supported models keyed by name, with each label
// prefixed by the plugin's label prefix. The error result is always nil
// and exists for interface symmetry.
func listModels() (map[string]ai.ModelInfo, error) {
	out := make(map[string]ai.ModelInfo, len(azureOpenAIModels))
	for _, id := range azureOpenAIModels {
		info, known := supportedAzureOpenAIModels[id]
		if !known {
			// No capability metadata for this name; skip it.
			continue
		}
		// info is a copy, so mutating its label is safe.
		info.Label = labelPrefix + " - " + info.Label
		out[id] = info
	}
	return out, nil
}
// listEmbedders returns the names of the supported embedding models.
// The error result is always nil and exists for interface symmetry.
func listEmbedders() ([]string, error) {
	embedders := []string{textEmbedding3Large, textEmbedding3Small}
	return embedders, nil
}
// Copyright 2025 herosizy
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
package azopenai
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
)
// OpenAIConfig represents the configuration options for Azure OpenAI models.
//
// DeploymentName defaults to the Genkit model name when left empty (see
// defineModel). Pointer fields are only forwarded to the API when non-nil.
type OpenAIConfig struct {
	ai.GenerationCommonConfig
	DeploymentName   string            `json:"deploymentName,omitempty"`   // Azure OpenAI deployment name
	MaxTokens        *int32            `json:"maxTokens,omitempty"`        // Maximum number of tokens to generate
	Temperature      *float32          `json:"temperature,omitempty"`      // Controls randomness (0.0 to 2.0)
	TopP             *float32          `json:"topP,omitempty"`             // Nucleus sampling parameter
	PresencePenalty  *float32          `json:"presencePenalty,omitempty"`  // Presence penalty (-2.0 to 2.0)
	FrequencyPenalty *float32          `json:"frequencyPenalty,omitempty"` // Frequency penalty (-2.0 to 2.0)
	LogitBias        map[string]*int32 `json:"logitBias,omitempty"`        // Logit bias modifications (matches the SDK's map type)
	User             string            `json:"user,omitempty"`             // User identifier
	Seed             *int64            `json:"seed,omitempty"`             // Random seed for deterministic outputs (matches the SDK's type)
}
// EmbedConfig contains per-request configuration for embedding requests.
// When absent, the deployment name defaults to the embedder's model name.
type EmbedConfig struct {
	DeploymentName string `json:"deploymentName,omitempty"` // Azure OpenAI deployment name
	User           string `json:"user,omitempty"`           // Optional end-user identifier forwarded to the API
}
// defineModel creates and registers a model with Genkit under the
// azureOpenAIProvider namespace. The generate function converts the
// Genkit request into the Azure SDK format and dispatches to the
// streaming or non-streaming handler depending on whether a stream
// callback was supplied.
func defineModel(g *genkit.Genkit, client *azopenai.Client, name string, info ai.ModelInfo) ai.Model {
	return genkit.DefineModel(g, azureOpenAIProvider, name, &info,
		func(ctx context.Context, mr *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) {
			// Extract config from the request; a config of any other
			// concrete type is silently ignored and defaults are used.
			var cfg OpenAIConfig
			if mr.Config != nil {
				if typedCfg, ok := mr.Config.(*OpenAIConfig); ok {
					cfg = *typedCfg
				}
			}
			// Default the deployment name to the model name.
			// NOTE(review): this also writes the defaulted config back onto
			// mr.Config — a visible side effect on the caller's request;
			// confirm whether anything downstream depends on it.
			if cfg.DeploymentName == "" {
				cfg.DeploymentName = name
				mr.Config = &cfg
			}
			// Convert the Genkit request to Azure OpenAI format.
			azRequest, err := convertToAzureOpenAIRequest(mr, cfg)
			if err != nil {
				return nil, fmt.Errorf("failed to convert request: %w", err)
			}
			// A non-nil callback selects the streaming path.
			if cb != nil {
				return handleStreamingRequest(ctx, client, azRequest, cb)
			} else {
				return handleNonStreamingRequest(ctx, client, azRequest)
			}
		})
}
// convertToAzureOpenAIRequest converts a Genkit ModelRequest into the
// Azure OpenAI ChatCompletionsOptions shape, copying over messages,
// tuning parameters, and tool definitions.
func convertToAzureOpenAIRequest(mr *ai.ModelRequest, cfg OpenAIConfig) (azopenai.ChatCompletionsOptions, error) {
	var zero azopenai.ChatCompletionsOptions

	// Translate every message; any unsupported role aborts the conversion.
	msgs := make([]azopenai.ChatRequestMessageClassification, 0, len(mr.Messages))
	for _, m := range mr.Messages {
		converted, err := convertMessage(m)
		if err != nil {
			return zero, err
		}
		msgs = append(msgs, converted)
	}

	deployment := cfg.DeploymentName
	if deployment == "" {
		return zero, errors.New("deployment name is required")
	}

	req := azopenai.ChatCompletionsOptions{
		Messages:       msgs,
		DeploymentName: &deployment,
	}

	// Optional tuning parameters: pointer fields copy over directly
	// (nil means "unset" on both sides).
	req.MaxTokens = cfg.MaxTokens
	req.Temperature = cfg.Temperature
	req.TopP = cfg.TopP
	req.PresencePenalty = cfg.PresencePenalty
	req.FrequencyPenalty = cfg.FrequencyPenalty
	req.Seed = cfg.Seed
	if len(cfg.LogitBias) > 0 {
		req.LogitBias = cfg.LogitBias
	}
	if cfg.User != "" {
		req.User = &cfg.User
	}

	// Attach tool definitions when the request carries any.
	if len(mr.Tools) > 0 {
		converted, err := convertTools(mr.Tools)
		if err != nil {
			return zero, err
		}
		req.Tools = converted
	}
	return req, nil
}
// convertMessage converts a Genkit message to the Azure OpenAI request
// message type matching its role. Only the concatenated text content of
// the message parts is carried over (see extractTextContent); media and
// tool-call parts are dropped.
func convertMessage(msg *ai.Message) (azopenai.ChatRequestMessageClassification, error) {
	content := extractTextContent(msg.Content)
	switch msg.Role {
	case ai.RoleSystem:
		return &azopenai.ChatRequestSystemMessage{
			Content: azopenai.NewChatRequestSystemMessageContent(content),
		}, nil
	case ai.RoleUser:
		return &azopenai.ChatRequestUserMessage{
			Content: azopenai.NewChatRequestUserMessageContent(content),
		}, nil
	case ai.RoleModel:
		return &azopenai.ChatRequestAssistantMessage{
			Content: azopenai.NewChatRequestAssistantMessageContent(content),
		}, nil
	case ai.RoleTool:
		// Tool messages need special handling.
		// NOTE(review): the tool call ID below is a hard-coded placeholder,
		// not the ID of the originating tool call; conversations that round-
		// trip tool results will not correlate correctly until the real ID
		// is tracked from the model's response.
		return &azopenai.ChatRequestToolMessage{
			Content:    azopenai.NewChatRequestToolMessageContent(content),
			ToolCallID: to.Ptr("tool_call_id"), // This should be properly tracked
		}, nil
	default:
		return nil, fmt.Errorf("unsupported role: %s", msg.Role)
	}
}
// extractTextContent concatenates the text of all text parts in order,
// with no separator. Non-text parts are ignored.
func extractTextContent(parts []*ai.Part) string {
	var b strings.Builder
	for _, p := range parts {
		if p.IsText() {
			b.WriteString(p.Text)
		}
		// TODO: Handle media parts for multimodal models
	}
	return b.String()
}
// convertTools converts Genkit tool definitions to the Azure OpenAI
// function-tool format. Each tool's input schema is marshaled to JSON for
// the Parameters field.
func convertTools(tools []*ai.ToolDefinition) ([]azopenai.ChatCompletionsToolDefinitionClassification, error) {
	azTools := make([]azopenai.ChatCompletionsToolDefinitionClassification, len(tools))
	for i, tool := range tools {
		tool := tool // fix: on Go <1.22, &tool.Name below would otherwise alias the single shared loop variable, making every entry point at the last tool
		parametersBytes, err := json.Marshal(tool.InputSchema)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal tool parameters: %w", err)
		}
		azTools[i] = &azopenai.ChatCompletionsFunctionToolDefinition{
			Type: to.Ptr("function"),
			Function: &azopenai.ChatCompletionsFunctionToolDefinitionFunction{
				Name:        &tool.Name,
				Description: &tool.Description,
				Parameters:  parametersBytes,
			},
		}
	}
	return azTools, nil
}
// handleStreamingRequest performs a streaming chat completion, invoking cb
// with each text delta as it arrives, and returns a final ModelResponse
// containing the fully accumulated text.
func handleStreamingRequest(ctx context.Context, client *azopenai.Client, options azopenai.ChatCompletionsOptions, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) {
	// Re-pack the non-streaming options into the streaming options type,
	// requesting a single choice (N=1).
	resp, err := client.GetChatCompletionsStream(ctx, azopenai.ChatCompletionsStreamOptions{
		Messages:         options.Messages,
		DeploymentName:   options.DeploymentName,
		MaxTokens:        options.MaxTokens,
		Temperature:      options.Temperature,
		TopP:             options.TopP,
		PresencePenalty:  options.PresencePenalty,
		FrequencyPenalty: options.FrequencyPenalty,
		LogitBias:        options.LogitBias,
		User:             options.User,
		Seed:             options.Seed,
		Tools:            options.Tools,
		N:                to.Ptr[int32](1),
	}, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get chat completions stream: %w", err)
	}
	defer resp.ChatCompletionsStream.Close()
	var fullContent strings.Builder
	var finishReason ai.FinishReason // zero value until the stream reports one
	// Read events until EOF; any other read error aborts the request.
	for {
		chatCompletion, err := resp.ChatCompletionsStream.Read()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("failed to read chat completion: %w", err)
		}
		for _, choice := range chatCompletion.Choices {
			if choice.Delta.Content != nil {
				content := *choice.Delta.Content
				fullContent.WriteString(content)
				// Forward the delta to the streaming callback; a callback
				// error cancels the whole generation.
				if cb != nil {
					chunk := &ai.ModelResponseChunk{
						Content: []*ai.Part{ai.NewTextPart(content)},
						Role:    ai.RoleModel,
					}
					if err := cb(ctx, chunk); err != nil {
						return nil, fmt.Errorf("streaming callback error: %w", err)
					}
				}
			}
			// The last non-nil finish reason seen wins.
			if choice.FinishReason != nil {
				finishReason = convertFinishReason(*choice.FinishReason)
			}
		}
	}
	// Return the final response with the accumulated text.
	return &ai.ModelResponse{
		Message: &ai.Message{
			Content: []*ai.Part{ai.NewTextPart(fullContent.String())},
			Role:    ai.RoleModel,
		},
		FinishReason: finishReason,
	}, nil
}
// handleNonStreamingRequest performs a blocking chat completion and
// converts the first returned choice into a Genkit ModelResponse.
func handleNonStreamingRequest(ctx context.Context, client *azopenai.Client, options azopenai.ChatCompletionsOptions) (*ai.ModelResponse, error) {
	resp, err := client.GetChatCompletions(ctx, options, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get chat completions: %w", err)
	}
	if len(resp.Choices) == 0 {
		return nil, errors.New("no choices returned from Azure OpenAI")
	}
	first := resp.Choices[0]

	var text string
	if first.Message.Content != nil {
		text = *first.Message.Content
	}

	// Default to "stop" when the API omits a finish reason.
	reason := ai.FinishReasonStop
	if first.FinishReason != nil {
		reason = convertFinishReason(*first.FinishReason)
	}

	msg := &ai.Message{
		Role:    ai.RoleModel,
		Content: []*ai.Part{ai.NewTextPart(text)},
	}
	return &ai.ModelResponse{Message: msg, FinishReason: reason}, nil
}
// convertFinishReason maps an Azure OpenAI finish reason onto the Genkit
// equivalent; anything unrecognized maps to FinishReasonOther.
func convertFinishReason(reason azopenai.CompletionsFinishReason) ai.FinishReason {
	switch reason {
	case azopenai.CompletionsFinishReasonTokenLimitReached:
		return ai.FinishReasonLength
	case azopenai.CompletionsFinishReasonContentFiltered:
		return ai.FinishReasonBlocked
	case azopenai.CompletionsFinishReasonStopped,
		azopenai.CompletionsFinishReasonToolCalls: // TODO: Handle tool calls properly
		return ai.FinishReasonStop
	default:
		return ai.FinishReasonOther
	}
}
// defineEmbedder registers an embedder for the named embedding model. The
// embed function flattens each input document's text parts into a single
// string, calls the Azure OpenAI embeddings API, and converts the result
// back into Genkit embeddings.
func defineEmbedder(g *genkit.Genkit, client *azopenai.Client, name string) ai.Embedder {
	return genkit.DefineEmbedder(g, azureOpenAIProvider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
		// Use the caller's EmbedConfig when supplied; otherwise default the
		// deployment name to the embedder's model name.
		var config *EmbedConfig
		if opts, ok := req.Options.(*EmbedConfig); ok {
			config = opts
		} else {
			config = &EmbedConfig{
				DeploymentName: name,
			}
		}
		// Flatten each document into one string: its non-empty text parts
		// joined with single spaces. Documents with no text are skipped.
		var input []string
		for _, doc := range req.Input {
			var textParts []string
			for _, part := range doc.Content {
				if part.Text != "" {
					textParts = append(textParts, part.Text)
				}
			}
			if len(textParts) > 0 {
				input = append(input, strings.Join(textParts, " "))
			}
		}
		if len(input) == 0 {
			return nil, fmt.Errorf("no text content found in input documents")
		}
		// Call the Azure OpenAI embeddings API.
		body := azopenai.EmbeddingsOptions{
			Input:          input,
			DeploymentName: to.Ptr(config.DeploymentName),
		}
		if config.User != "" {
			body.User = to.Ptr(config.User)
		}
		resp, err := client.GetEmbeddings(ctx, body, nil)
		if err != nil {
			return nil, fmt.Errorf("failed to get embeddings from Azure OpenAI: %w", err)
		}
		// Convert the SDK response into Genkit embeddings, preserving order.
		// NOTE(review): this assumes resp.Data is ordered to match input —
		// confirm against the SDK's documented ordering guarantee.
		var embeddings []*ai.Embedding
		for _, item := range resp.Data {
			embeddings = append(embeddings, &ai.Embedding{
				Embedding: item.Embedding,
			})
		}
		return &ai.EmbedResponse{
			Embeddings: embeddings,
		}, nil
	})
}