// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package commands
import (
        "errors"
        "fmt"
        "os"
        "github.com/openpubkey/opkssh/policy"
)
// AddCmd provides functionality to read and update the opkssh policy file.
// It prefers the system-wide policy and falls back to the user's own policy
// file when the current process is not permitted to read the system policy.
type AddCmd struct {
        // HomePolicyLoader loads, resolves the path of, and (via its PolicyLoader)
        // writes the per-user policy file.
        HomePolicyLoader   *policy.HomePolicyLoader
        // SystemPolicyLoader loads and (via its PolicyLoader) writes the system
        // policy at policy.SystemDefaultPolicyPath.
        SystemPolicyLoader *policy.SystemPolicyLoader
        // Username is the username to lookup when the system policy file cannot be
        // read and we fallback to the user's policy file.
        //
        // See AddCmd.LoadPolicy for more details.
        Username string
}
// LoadPolicy returns the opkssh policy that AddCmd operates on.
//
// The system policy at policy.SystemDefaultPolicyPath is tried first. Only if
// reading it fails with a permission error does LoadPolicy fall back to the
// user's local policy file (~/.opk/auth_id, where ~ is AddCmd.Username's home
// directory). Any other error from the system policy loader (e.g. the file is
// missing or has invalid permission bits) is returned unchanged.
//
// On success it returns the parsed policy and the path it was read from;
// otherwise it returns a non-nil error.
func (a *AddCmd) LoadPolicy() (*policy.Policy, string, error) {
	sysPolicy, _, sysErr := a.SystemPolicyLoader.LoadSystemPolicy()
	switch {
	case sysErr == nil:
		return sysPolicy, policy.SystemDefaultPolicyPath, nil
	case !errors.Is(sysErr, os.ErrPermission):
		return nil, "", sysErr
	}

	// The process may not read the system policy: use the user's policy instead.
	homePolicy, homePath, homeErr := a.HomePolicyLoader.LoadHomePolicy(a.Username, false)
	if homeErr != nil {
		return nil, "", homeErr
	}
	return homePolicy, homePath, nil
}
// GetPolicyPath returns the path of the policy file that the current command
// will write to, plus a flag telling which policy that path belongs to:
// true means the system policy (policy.SystemDefaultPolicyPath), false means
// the user's home policy resolved for AddCmd.Username.
//
// Access to the system policy is probed by trying to load it. If that fails
// with a permission error, the home policy path is returned instead; the home
// policy itself is not read here. Any other error (e.g. the system policy file
// is missing or has invalid permission bits) is returned unchanged.
//
// principal, userEmail and issuer are currently unused; they are kept so the
// signature stays stable for existing callers.
func (a *AddCmd) GetPolicyPath(principal string, userEmail string, issuer string) (string, bool, error) {
	_, _, err := a.SystemPolicyLoader.LoadSystemPolicy()
	if err == nil {
		return policy.SystemDefaultPolicyPath, true, nil
	}
	if !errors.Is(err, os.ErrPermission) {
		return "", false, err
	}

	// No permission to read the system policy: target the user's policy file.
	homePath, err := a.HomePolicyLoader.UserPolicyPath(a.Username)
	if err != nil {
		return "", false, err
	}
	return homePath, false, nil
}
// Run grants principal to the identity (userEmail, issuer) in the opkssh
// policy and writes the updated policy back to disk.
//
// Steps: pick the target policy file (GetPolicyPath), create it if it does not
// exist yet using the matching loader, load the current policy (LoadPolicy),
// add the allowed principal, and dump the result to the path LoadPolicy
// reported.
//
// On success it returns the path of the policy file that was written;
// otherwise it returns a non-nil error.
func (a *AddCmd) Run(principal string, userEmail string, issuer string) (string, error) {
	targetPath, isSystem, err := a.GetPolicyPath(principal, userEmail, issuer)
	if err != nil {
		return "", fmt.Errorf("failed to load policy: %w", err)
	}

	// Only dereference the loader that matches the selected policy.
	var loader *policy.PolicyLoader
	if isSystem {
		loader = a.SystemPolicyLoader.PolicyLoader
	} else {
		loader = a.HomePolicyLoader.PolicyLoader
	}

	if err := loader.CreateIfDoesNotExist(targetPath); err != nil {
		return "", fmt.Errorf("failed to create policy file: %w", err)
	}

	existing, writtenPath, err := a.LoadPolicy()
	if err != nil {
		return "", fmt.Errorf("failed to load current policy: %w", err)
	}

	existing.AddAllowedPrincipal(principal, userEmail, issuer)

	if err := loader.Dump(existing, writtenPath); err != nil {
		return "", fmt.Errorf("failed to write updated policy: %w", err)
	}
	return writtenPath, nil
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package config
import (
        _ "embed"
        "fmt"
        "log"
        "os"
        "path/filepath"
        "github.com/spf13/afero"
        "gopkg.in/yaml.v3"
)
// DefaultClientConfig holds the contents of default-client-config.yml, embedded
// into the binary at build time. CreateDefaultClientConfig writes it to disk.
//go:embed default-client-config.yml
var DefaultClientConfig []byte
// ClientConfig is the opkssh client configuration decoded from YAML
// (see NewClientConfig and GetClientConfigFromFile).
type ClientConfig struct {
        // DefaultProvider is read from "default_provider". NOTE(review): presumably
        // names one of the provider aliases — confirm against callers.
        DefaultProvider string           `yaml:"default_provider"`
        // Providers lists the configured OpenID Providers ("providers").
        Providers       []ProviderConfig `yaml:"providers"`
}
// NewClientConfig decodes raw YAML bytes into a ClientConfig. On failure it
// returns the YAML decoder's error unchanged and a nil config.
func NewClientConfig(c []byte) (*ClientConfig, error) {
	cfg := &ClientConfig{}
	if err := yaml.Unmarshal(c, cfg); err != nil {
		return nil, err
	}
	return cfg, nil
}
// GetProvidersMap indexes the configured providers by alias. It fails if two
// providers share an alias (see CreateProvidersMap).
func (c *ClientConfig) GetProvidersMap() (map[string]ProviderConfig, error) {
	configured := c.Providers
	return CreateProvidersMap(configured)
}
// GetByIssuer looks up an OpenID Provider by exact issuer URL. When several
// providers share the issuer, the first one in c.Providers wins.
//
// The returned pointer refers to a copy of the matching entry, so changing it
// does not modify c.Providers. If nothing matches it returns (nil, false).
func (c *ClientConfig) GetByIssuer(issuer string) (*ProviderConfig, bool) {
	for idx := range c.Providers {
		if c.Providers[idx].Issuer != issuer {
			continue
		}
		match := c.Providers[idx]
		return &match, true
	}
	return nil, false
}
// ResolveClientConfigPath fills in the default client config path when
// *configPath is empty: ~/.opk/config.yml, where ~ is the current user's home
// directory as reported by os.UserHomeDir. A non-empty *configPath is left
// unchanged.
//
// It returns an error if configPath is nil (rather than panicking) or if the
// home directory cannot be determined.
func ResolveClientConfigPath(configPath *string) error {
	if configPath == nil {
		return fmt.Errorf("config path pointer must not be nil")
	}
	if *configPath != "" {
		return nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("failed to get user home dir: %w", err)
	}
	*configPath = filepath.Join(home, ".opk", "config.yml")
	return nil
}
// GetClientConfigFromFile reads and parses the client config stored at
// configPath on the filesystem Fs. If configPath is empty, the default path
// ~/.opk/config.yml is used (see ResolveClientConfigPath).
//
// It returns an error if the path cannot be resolved, the file cannot be read,
// or its contents are not a valid client config.
func GetClientConfigFromFile(configPath string, Fs afero.Fs) (*ClientConfig, error) {
	if err := ResolveClientConfigPath(&configPath); err != nil {
		return nil, err
	}

	afs := &afero.Afero{Fs: Fs}
	configBytes, err := afs.ReadFile(configPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}

	config, err := NewClientConfig(configBytes)
	if err != nil {
		// YAML errors carry no file name; add it so users know which file is broken.
		return nil, fmt.Errorf("failed to parse config file %s: %w", configPath, err)
	}
	return config, nil
}
// CreateDefaultClientConfig writes the embedded DefaultClientConfig to
// configPath on Fs. The parent directory is created first (mode 0755) and
// the file is written with mode 0644; the location is logged on success.
func CreateDefaultClientConfig(configPath string, Fs afero.Fs) error {
	fsys := &afero.Afero{Fs: Fs}

	parentDir := filepath.Dir(configPath)
	if err := fsys.MkdirAll(parentDir, 0o755); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}

	err := fsys.WriteFile(configPath, DefaultClientConfig, 0o644)
	if err != nil {
		return fmt.Errorf("failed to write default config file: %w", err)
	}

	log.Printf("created client config file at %s", configPath)
	return nil
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package config
import (
        "fmt"
        "os"
        "strings"
        "github.com/openpubkey/openpubkey/providers"
        "gopkg.in/yaml.v3"
)
const (
        // WEBCHOOSER_ALIAS is a reserved provider alias. NOTE(review): presumably
        // selects the interactive web-based provider chooser — confirm at use site.
        WEBCHOOSER_ALIAS        = "WEBCHOOSER"
        // OPKSSH_DEFAULT_ENVVAR names the environment variable that can be set
        // to the alias of the default provider.
        OPKSSH_DEFAULT_ENVVAR   = "OPKSSH_DEFAULT"
        // OPKSSH_PROVIDERS_ENVVAR names the environment variable holding a
        // ';'-separated list of provider specs (see GetProvidersConfigFromEnv).
        OPKSSH_PROVIDERS_ENVVAR = "OPKSSH_PROVIDERS"
)
// ProviderConfig describes one OpenID Provider that opkssh can log in with.
// Its YAML form is decoded by UnmarshalYAML, where "alias" and "scopes" are
// whitespace-separated strings rather than lists.
type ProviderConfig struct {
        // AliasList holds the names this provider can be selected by; aliases must
        // be unique across providers (see CreateProvidersMap).
        AliasList       []string `yaml:"alias"`
        // Issuer is the provider's issuer URL; ToProvider requires "https://".
        Issuer          string   `yaml:"issuer"`
        // ClientID is the OIDC client ID; ToProvider rejects an empty value.
        ClientID        string   `yaml:"client_id"`
        // ClientSecret is optional in general but required for Google issuers
        // (see NewProviderConfigFromString).
        ClientSecret    string   `yaml:"client_secret,omitempty"`
        // Scopes are the OIDC scopes to request; when empty (or a single empty
        // string) ToProvider keeps the provider's default scopes.
        Scopes          []string `yaml:"scopes"`
        // AccessType is passed through as the provider option AccessType.
        AccessType      string   `yaml:"access_type,omitempty"`
        // Prompt is passed through as the provider option PromptType.
        Prompt          string   `yaml:"prompt,omitempty"`
        // RedirectURIs are the local login-callback URIs given to the provider.
        RedirectURIs    []string `yaml:"redirect_uris"`
        // SendAccessToken is decoded from "send_access_token". NOTE(review):
        // presumably asks that the access token be sent alongside the PK Token
        // (cf. LoginCmd.SendAccessTokenArg) — confirm at use site.
        SendAccessToken bool     `yaml:"send_access_token,omitempty"`
}
// UnmarshalYAML implements yaml.Unmarshaler for ProviderConfig.
//
// In YAML, "alias" and "scopes" are single strings of whitespace-separated
// values; they are split with strings.Fields into AliasList and Scopes.
// Before decoding, scopes, access_type, prompt and redirect_uris are
// pre-populated with defaults ("openid profile email", "offline", "consent"
// and three localhost login-callback URIs), so keys the node omits keep those
// defaults. If decoding fails, p is left untouched and the error is returned.
func (p *ProviderConfig) UnmarshalYAML(value *yaml.Node) error {
	type rawProviderConfig struct {
		AliasList       string   `yaml:"alias"`
		Issuer          string   `yaml:"issuer"`
		ClientID        string   `yaml:"client_id"`
		ClientSecret    string   `yaml:"client_secret"`
		Scopes          string   `yaml:"scopes"`
		AccessType      string   `yaml:"access_type"`
		Prompt          string   `yaml:"prompt"`
		RedirectURIs    []string `yaml:"redirect_uris"`
		SendAccessToken bool     `yaml:"send_access_token,omitempty"`
	}

	raw := rawProviderConfig{
		Scopes:     "openid profile email",
		AccessType: "offline",
		Prompt:     "consent",
		RedirectURIs: []string{
			"http://localhost:3000/login-callback",
			"http://localhost:10001/login-callback",
			"http://localhost:11110/login-callback",
		},
	}
	if err := value.Decode(&raw); err != nil {
		return err
	}

	*p = ProviderConfig{
		AliasList:       strings.Fields(raw.AliasList),
		Issuer:          raw.Issuer,
		ClientID:        raw.ClientID,
		ClientSecret:    raw.ClientSecret,
		Scopes:          strings.Fields(raw.Scopes),
		AccessType:      raw.AccessType,
		Prompt:          raw.Prompt,
		RedirectURIs:    raw.RedirectURIs,
		SendAccessToken: raw.SendAccessToken,
	}
	return nil
}
// DefaultProviderConfig returns a ProviderConfig with no alias (an empty,
// non-nil AliasList), no issuer, client ID or client secret, the scopes
// "openid" and "email", access type "offline", prompt "consent", and the three
// localhost login-callback redirect URIs. Callers fill in the rest.
//
// TODO: Move this into OpenPubkey providers package
func DefaultProviderConfig() ProviderConfig {
	cfg := ProviderConfig{
		AliasList:  []string{},
		Scopes:     []string{"openid", "email"},
		AccessType: "offline",
		Prompt:     "consent",
	}
	cfg.RedirectURIs = []string{
		"http://localhost:3000/login-callback",
		"http://localhost:10001/login-callback",
		"http://localhost:11110/login-callback",
	}
	return cfg
}
// GitHubProviderConfig returns the configuration for the GitHub Actions OIDC
// issuer (https://token.actions.githubusercontent.com), selectable by the
// alias "github". All other fields are left at their zero values.
func GitHubProviderConfig() ProviderConfig {
	var cfg ProviderConfig
	cfg.AliasList = []string{"github"}
	cfg.Issuer = "https://token.actions.githubusercontent.com"
	// ToProvider rejects an empty client ID, but this provider never uses it.
	cfg.ClientID = "unused"
	return cfg
}
// NewProviderConfigFromString builds a ProviderConfig from a comma-separated
// provider spec. With hasAlias the format is
//
//	<alias>,<issuer>,<client_id>[,<client_secret>[,<scopes>]]
//
// and without it the leading <alias> is omitted (AliasList is then [""]).
// <scopes> is a whitespace-separated list, split with strings.Fields like the
// YAML form, so repeated spaces do not produce empty scopes. When <scopes> is
// absent the scopes default to "openid email". Fields after <scopes> are
// ignored.
//
// It returns an error if fewer than issuer and client ID are given, if the
// client ID is empty, or if a Google issuer (https://accounts.google.com...)
// lacks a client secret.
func NewProviderConfigFromString(configStr string, hasAlias bool) (ProviderConfig, error) {
	parts := strings.Split(configStr, ",")
	alias := ""
	if hasAlias {
		// The first field is the alias; the rest follow the alias-less format.
		alias = parts[0]
		parts = parts[1:]
	}
	if len(parts) < 2 {
		if hasAlias {
			return ProviderConfig{}, fmt.Errorf("invalid provider config string. Expected format <alias>,<issuer>,<client_id> or <alias>,<issuer>,<client_id>,<client_secret> or <alias>,<issuer>,<client_id>,<client_secret>,<scopes>")
		}
		return ProviderConfig{}, fmt.Errorf("invalid provider config string. Expected format <issuer>,<client_id> or <issuer>,<client_id>,<client_secret> or <issuer>,<client_id>,<client_secret>,<scopes>")
	}

	providerConfig := DefaultProviderConfig()
	providerConfig.AliasList = []string{alias}
	providerConfig.Issuer = parts[0]
	providerConfig.ClientID = parts[1]
	if providerConfig.ClientID == "" {
		return ProviderConfig{}, fmt.Errorf("invalid provider client-ID value got (%s)", providerConfig.ClientID)
	}

	providerConfig.ClientSecret = ""
	if len(parts) > 2 {
		providerConfig.ClientSecret = parts[2]
	}

	providerConfig.Scopes = []string{"openid", "email"}
	if len(parts) > 3 {
		providerConfig.Scopes = strings.Fields(parts[3])
	}

	// The Google OP is strange in that it requires a client secret even if this is a public OIDC App.
	// Despite its name the Google OP client secret is a public value.
	if strings.HasPrefix(providerConfig.Issuer, "https://accounts.google.com") && providerConfig.ClientSecret == "" {
		if hasAlias {
			return ProviderConfig{}, fmt.Errorf("invalid provider argument format. Expected format for google: <alias>,<issuer>,<client_id>,<client_secret>")
		}
		return ProviderConfig{}, fmt.Errorf("invalid provider argument format. Expected format for google: <issuer>,<client_id>,<client_secret>")
	}
	return providerConfig, nil
}
// ToProvider builds the OpenPubkey OpenID Provider described by this config.
//
// It requires a non-empty issuer starting with "https://" and a non-empty
// client ID. The issuer selects the implementation:
//   - https://accounts.google.com...       Google provider
//   - https://login.microsoftonline.com... Azure provider
//   - https://gitlab.com...                GitLab provider
//   - https://issuer.hello.coop (exact)    Hello provider
//   - https://token.actions.githubusercontent.com... GitHub Actions provider,
//     built from the environment (the remaining fields are not used)
//   - anything else                        standard OIDC provider
//
// For every option-based provider, GQ signing is disabled, Scopes override the
// provider defaults only when hasScopes reports true, and openBrowser is passed
// as the OpenBrowser option. Only the Google and standard providers receive
// ClientSecret. NOTE(review): confirm whether the Azure/GitLab/Hello option
// types accept a client secret.
func (p *ProviderConfig) ToProvider(openBrowser bool) (providers.OpenIdProvider, error) {
	if p.Issuer == "" {
		return nil, fmt.Errorf("invalid provider issuer value got (%s)", p.Issuer)
	}
	if !strings.HasPrefix(p.Issuer, "https://") {
		return nil, fmt.Errorf("invalid provider issuer value. Expected issuer to start with 'https://' got (%s)", p.Issuer)
	}
	if p.ClientID == "" {
		return nil, fmt.Errorf("invalid provider client-ID value got (%s)", p.ClientID)
	}

	var provider providers.OpenIdProvider
	switch {
	case strings.HasPrefix(p.Issuer, "https://accounts.google.com"):
		opts := providers.GetDefaultGoogleOpOptions()
		opts.Issuer = p.Issuer
		opts.ClientID = p.ClientID
		opts.ClientSecret = p.ClientSecret
		opts.GQSign = false
		if p.hasScopes() {
			opts.Scopes = p.Scopes
		}
		opts.PromptType = p.Prompt
		opts.AccessType = p.AccessType
		opts.RedirectURIs = p.RedirectURIs
		opts.OpenBrowser = openBrowser
		provider = providers.NewGoogleOpWithOptions(opts)

	case strings.HasPrefix(p.Issuer, "https://login.microsoftonline.com"):
		opts := providers.GetDefaultAzureOpOptions()
		opts.Issuer = p.Issuer
		opts.ClientID = p.ClientID
		opts.GQSign = false
		if p.hasScopes() {
			opts.Scopes = p.Scopes
		}
		opts.PromptType = p.Prompt
		opts.AccessType = p.AccessType
		opts.RedirectURIs = p.RedirectURIs
		opts.OpenBrowser = openBrowser
		provider = providers.NewAzureOpWithOptions(opts)

	case strings.HasPrefix(p.Issuer, "https://gitlab.com"):
		opts := providers.GetDefaultGitlabOpOptions()
		opts.Issuer = p.Issuer
		opts.ClientID = p.ClientID
		opts.GQSign = false
		if p.hasScopes() {
			opts.Scopes = p.Scopes
		}
		opts.PromptType = p.Prompt
		opts.AccessType = p.AccessType
		opts.RedirectURIs = p.RedirectURIs
		opts.OpenBrowser = openBrowser
		provider = providers.NewGitlabOpWithOptions(opts)

	case p.Issuer == "https://issuer.hello.coop":
		opts := providers.GetDefaultHelloOpOptions()
		opts.Issuer = p.Issuer
		opts.ClientID = p.ClientID
		opts.GQSign = false
		if p.hasScopes() {
			opts.Scopes = p.Scopes
		}
		opts.PromptType = p.Prompt
		opts.AccessType = p.AccessType
		opts.RedirectURIs = p.RedirectURIs
		opts.OpenBrowser = openBrowser
		provider = providers.NewHelloOpWithOptions(opts)

	case strings.HasPrefix(p.Issuer, "https://token.actions.githubusercontent.com"):
		githubOp, err := providers.NewGithubOpFromEnvironment()
		if err != nil {
			return nil, fmt.Errorf("error creating github op: %w", err)
		}
		provider = githubOp

	default:
		// Generic provider
		opts := providers.GetDefaultStandardOpOptions(p.Issuer, p.ClientID)
		opts.ClientSecret = p.ClientSecret
		opts.PromptType = p.Prompt
		opts.AccessType = p.AccessType
		opts.RedirectURIs = p.RedirectURIs
		opts.GQSign = false
		if p.hasScopes() {
			opts.Scopes = p.Scopes
		}
		opts.OpenBrowser = openBrowser
		provider = providers.NewStandardOpWithOptions(opts)
	}
	return provider, nil
}
// hasScopes reports whether p.Scopes carries real scopes, i.e. it is neither
// empty nor a lone empty string. ToProvider only overrides a provider's default
// scopes when this is true.
func (p *ProviderConfig) hasScopes() bool {
	switch len(p.Scopes) {
	case 0:
		return false
	case 1:
		return p.Scopes[0] != ""
	default:
		return true
	}
}
// GetProvidersConfigFromEnv parses the provider list held in the
// OPKSSH_PROVIDERS environment variable: a ';'-separated list of specs in the
// format <alias>,<issuer>,<client_id>,<client_secret>,<scopes>.
// (OPKSSH_DEFAULT, which can name a default alias, is not read here.)
//
// If the variable is unset or empty it returns (nil, nil). Parse failures are
// wrapped with context.
func GetProvidersConfigFromEnv() ([]ProviderConfig, error) {
	raw, found := os.LookupEnv(OPKSSH_PROVIDERS_ENVVAR)
	if !found || raw == "" {
		return nil, nil
	}

	parsed, err := ProvidersConfigListFromStrings(raw)
	if err != nil {
		return nil, fmt.Errorf("error getting provider config from env: %w", err)
	}
	return parsed, nil
}
// ProvidersConfigListFromStrings parses a ';'-separated list of provider specs,
// each in the <alias>,<issuer>,<client_id>[,<client_secret>[,<scopes>]] format
// accepted by NewProviderConfigFromString, preserving their order.
//
// Whitespace around each entry (including newlines, common in multi-line
// environment values) is trimmed, and blank entries such as the one produced by
// a trailing ';' are skipped. It returns an error if any entry fails to parse
// or if the list contains no entries at all.
func ProvidersConfigListFromStrings(providerList string) ([]ProviderConfig, error) {
	entries := strings.Split(providerList, ";")
	providerConfigList := make([]ProviderConfig, 0, len(entries))
	for _, entry := range entries {
		providerStr := strings.TrimSpace(entry)
		if providerStr == "" {
			continue
		}
		providerConfig, err := NewProviderConfigFromString(providerStr, true)
		if err != nil {
			return nil, fmt.Errorf("error parsing provider config string: %w", err)
		}
		providerConfigList = append(providerConfigList, providerConfig)
	}
	if len(providerConfigList) == 0 {
		return nil, fmt.Errorf("error parsing provider config string: no providers specified in %q", providerList)
	}
	return providerConfigList, nil
}
// CreateProvidersMap indexes providers by every alias in their AliasList.
// The returned map is never nil on success (empty input yields an empty map).
// It fails on the first alias that appears more than once.
func CreateProvidersMap(providerConfigList []ProviderConfig) (map[string]ProviderConfig, error) {
	byAlias := map[string]ProviderConfig{}
	for i := range providerConfigList {
		cfg := providerConfigList[i]
		for _, name := range cfg.AliasList {
			if _, dup := byAlias[name]; dup {
				return nil, fmt.Errorf("duplicate provider alias found: %s", name)
			}
			byAlias[name] = cfg
		}
	}
	return byAlias, nil
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package config
import (
        "os"
        "gopkg.in/yaml.v3"
)
// ServerConfig represents the /etc/opk/config.yml file on the server that the
// user is SSHing into.
type ServerConfig struct {
        // EnvVars are exported into the process environment by SetEnvVars.
        EnvVars    map[string]string `yaml:"env_vars"`
        // DenyUsers is decoded from "deny_users". NOTE(review): presumably local
        // usernames to refuse; enforcement is not in this file — confirm.
        DenyUsers  []string          `yaml:"deny_users"`
        // DenyEmails is decoded from "deny_emails". NOTE(review): presumably
        // identity emails to refuse; enforcement is not in this file — confirm.
        DenyEmails []string          `yaml:"deny_emails"`
}
// NewServerConfig decodes YAML bytes into a ServerConfig. On failure it
// returns the YAML decoder's error unchanged and a nil config.
func NewServerConfig(c []byte) (*ServerConfig, error) {
	parsed := new(ServerConfig)
	if err := yaml.Unmarshal(c, parsed); err != nil {
		return nil, err
	}
	return parsed, nil
}
// SetEnvVars exports every EnvVars entry into the current process environment
// via os.Setenv, stopping at and returning the first error. Map iteration
// order is unspecified, so which variables were already set when an error
// occurs is not deterministic.
func (c *ServerConfig) SetEnvVars() error {
	for name, val := range c.EnvVars {
		err := os.Setenv(name, val)
		if err != nil {
			return err
		}
	}
	return nil
}
		
		// SPDX-License-Identifier: Apache-2.0
package commands
import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"slices"
	"strings"
	"time"

	"github.com/openpubkey/openpubkey/pktoken"
	"golang.org/x/crypto/ssh"
)
// InspectCmd prints a human-readable report about an SSH public key or SSH
// certificate, including any OpenPubkey PK Token carried in a certificate's
// "openpubkey-pkt" extension.
type InspectCmd struct {
        // KeyOrCert is the SSH key or certificate to be inspected. If it names an
        // existing file, Run replaces it with that file's contents.
        KeyOrCert string
        // Output is where output should be written to.
        Output io.Writer
}
// NewInspectCmd returns an InspectCmd that inspects keyOrCert (either the
// key/certificate text itself or a path to a file containing it) and writes
// its report to output.
func NewInspectCmd(keyOrCert string, output io.Writer) *InspectCmd {
	cmd := new(InspectCmd)
	cmd.KeyOrCert = keyOrCert
	cmd.Output = output
	return cmd
}
// printf writes a formatted string to i.Output (os.Stdout when Output is nil).
// If that write fails, the same text is written once to os.Stdout; an error
// from this fallback is ignored because there is nowhere left to report it.
func (i *InspectCmd) printf(format string, a ...any) {
	out := i.Output
	if out == nil {
		out = os.Stdout
	}
	if _, err := fmt.Fprintf(out, format, a...); err != nil {
		// Fall back to stdout directly. Calling i.printf here would retry the
		// same failing writer and recurse until the stack overflows.
		_, _ = fmt.Fprintf(os.Stdout, format, a...)
	}
}
// Run inspects i.KeyOrCert and writes a human-readable report to i.Output.
//
// If i.KeyOrCert names an existing file, that file's contents are used instead
// and stored back into i.KeyOrCert. The text is then trimmed and parsed as an
// authorized_keys-style line. SSH certificates get a certificate report
// (including any openpubkey-pkt extension); other keys get a public-key
// summary. Errors wrap the underlying cause (%w) so callers can inspect it.
func (i *InspectCmd) Run() error {
	if _, err := os.Stat(i.KeyOrCert); err == nil {
		data, err := os.ReadFile(i.KeyOrCert)
		if err != nil {
			return fmt.Errorf("error reading input file: %w", err)
		}
		i.KeyOrCert = string(data)
	}

	i.KeyOrCert = strings.TrimSpace(i.KeyOrCert)

	pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(i.KeyOrCert))
	if err != nil {
		return fmt.Errorf("failed to parse SSH key: %w", err)
	}

	if cert, ok := pubKey.(*ssh.Certificate); ok {
		i.inspectCertificate(cert)
		return nil
	}
	i.inspectPublicKey(pubKey)
	return nil
}
// inspectCertificate prints the certificate's header fields and its extensions,
// sorted by name so the output is stable across runs. The "openpubkey-pkt"
// extension is summarised by size in the list and then decoded in full by
// inspectPKToken; if it is absent, a note saying so is printed instead.
func (i *InspectCmd) inspectCertificate(cert *ssh.Certificate) {
	i.printf("--- SSH Certificate Information ---\n")
	i.printf("%-18s %d\n", "Serial:", cert.Serial)
	i.printf("%-18s %s\n", "Type:", certificateType(cert.CertType))
	i.printf("%-18s %s\n", "Key ID:", cert.KeyId)
	i.printf("%-18s %v\n", "Principals:", cert.ValidPrincipals)
	i.printf("%-18s %s\n", "Valid After:", formatTime(cert.ValidAfter))
	i.printf("%-18s %s\n", "Valid Before:", formatTime(cert.ValidBefore))
	i.printf("%-18s %v\n", "Critical Options:", cert.CriticalOptions)

	// Map iteration order is random; sort the names for deterministic output.
	extNames := make([]string, 0, len(cert.Extensions))
	for name := range cert.Extensions {
		extNames = append(extNames, name)
	}
	slices.Sort(extNames)

	i.printf("Extensions:\n")
	for _, name := range extNames {
		value := cert.Extensions[name]
		if name == "openpubkey-pkt" {
			i.printf("  %s: [PKToken data] %d bytes\n", name, len(value))
		} else {
			i.printf("  %s: %s\n", name, value)
		}
	}

	pktStr, ok := cert.Extensions["openpubkey-pkt"]
	if !ok {
		i.printf("\nNo openpubkey-pkt extension found\n")
		return
	}
	i.inspectPKToken(pktStr)
}
// formatTime renders an SSH certificate timestamp (seconds since the Unix
// epoch) as an RFC 3339 string in the local time zone.
//
// 0 is shown as "Not set" and the all-ones value 2^64-1 as "Forever". Any other
// value too large for int64 is printed numerically with an "(out of range)"
// marker. Converting it would wrap to a negative, meaningless date.
func formatTime(timestamp uint64) string {
	const maxUnixSeconds = 1<<63 - 1
	switch {
	case timestamp == 0:
		return "Not set"
	case timestamp == 1<<64-1:
		return "Forever"
	case timestamp > maxUnixSeconds:
		return fmt.Sprintf("%d (out of range)", timestamp)
	}
	return time.Unix(int64(timestamp), 0).Format(time.RFC3339)
}
// inspectPublicKey prints the type, SHA-256 fingerprint and a short prefix of
// the base64-encoded wire format of a plain (non-certificate) SSH public key.
func (i *InspectCmd) inspectPublicKey(pubKey ssh.PublicKey) {
	i.printf("--- SSH Public Key Information ---\n")
	i.printf("Type: %s\n", pubKey.Type())

	fingerprint := ssh.FingerprintSHA256(pubKey)
	i.printf("Fingerprint: %s\n", fingerprint)

	// Show only a prefix of the encoding. Bound the slice so an unusually short
	// encoding cannot cause an out-of-range panic.
	const previewLen = 20
	marshal := base64.StdEncoding.EncodeToString(pubKey.Marshal())
	if len(marshal) > previewLen {
		marshal = marshal[:previewLen]
	}
	i.printf("Marshal (base64): %s...\n", marshal)
}
// certificateType returns a display label for an SSH certificate type:
// "User Certificate", "Host Certificate", or "Unknown (<n>)" otherwise.
func certificateType(certType uint32) string {
	if certType == ssh.UserCert {
		return "User Certificate"
	}
	if certType == ssh.HostCert {
		return "Host Certificate"
	}
	return fmt.Sprintf("Unknown (%d)", certType)
}
// inspectPKToken decodes a compact PK Token and prints its payload as indented
// JSON. It then prints the protected headers of each signature present (OP,
// CIC, COS) and selected token metadata. If the token cannot be parsed, only
// the parse error is printed.
func (i *InspectCmd) inspectPKToken(pktStr string) {
	pkt, err := pktoken.NewFromCompact([]byte(pktStr))
	if err != nil {
		i.printf("Error parsing PKToken: %v\n", err)
		return
	}

	i.printf("\n--- PKToken Structure ---\n")
	i.printf("Payload:\n")
	i.printJSON(pkt.Payload)

	i.printf("\n--- Signature Information ---\n")
	if op := pkt.Op; op != nil {
		i.printf("Provider Signature (OP) exists\n")
		if hdrs := op.ProtectedHeaders(); hdrs != nil {
			i.printJSONObject(hdrs)
		}
	}
	if cic := pkt.Cic; cic != nil {
		i.printf("Client Signature (CIC) exists\n")
		if hdrs := cic.ProtectedHeaders(); hdrs != nil {
			i.printJSONObject(hdrs)
		}
	}
	if cos := pkt.Cos; cos != nil {
		i.printf("Cosigner Signature (COS) exists\n")
		if hdrs := cos.ProtectedHeaders(); hdrs != nil {
			i.printJSONObject(hdrs)
		}
	}

	i.printf("\n--- Token Metadata ---\n")
	i.printTokenMetadata(pkt)
}
// printJSON pretty-prints raw JSON bytes. If data is not valid JSON, it prints
// the unmarshal error followed by the raw bytes instead.
func (i *InspectCmd) printJSON(data []byte) {
	var decoded any
	err := json.Unmarshal(data, &decoded)
	if err == nil {
		i.printJSONObject(decoded)
		return
	}
	i.printf("Error unmarshalling JSON: %v\n", err)
	i.printf("%s\n", string(data))
}
// printJSONObject prints obj as two-space-indented JSON. If obj cannot be
// marshalled, it prints the error followed by obj's %v form instead.
func (i *InspectCmd) printJSONObject(obj any) {
	pretty, err := json.MarshalIndent(obj, "", "  ")
	if err == nil {
		i.printf("%s\n", string(pretty))
		return
	}
	i.printf("Error pretty-printing: %v\n", err)
	i.printf("%v\n", obj)
}
// printTokenMetadata prints selected PK Token fields as aligned "label value"
// lines: issuer, audience, subject, identity, token hash and provider
// algorithm. A field is skipped silently when its accessor returns an error
// (or, for the provider algorithm, reports it is unavailable).
func (i *InspectCmd) printTokenMetadata(pkt *pktoken.PKToken) {
	const row = "%-19s %s\n"

	if issuer, err := pkt.Issuer(); err == nil {
		i.printf(row, "Issuer:", issuer)
	}
	if aud, err := pkt.Audience(); err == nil {
		i.printf(row, "Audience:", aud)
	}
	if sub, err := pkt.Subject(); err == nil {
		i.printf(row, "Subject:", sub)
	}
	if identity, err := pkt.IdentityString(); err == nil {
		i.printf(row, "Identity:", identity)
	}
	// The hash is a compact way to tell tokens apart.
	if hash, err := pkt.Hash(); err == nil {
		i.printf(row, "Token Hash:", hash)
	}
	if alg, found := pkt.ProviderAlgorithm(); found {
		i.printf(row, "Provider Algorithm:", alg)
	}
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package commands
import (
        "bytes"
        "context"
        "crypto"
        "crypto/ecdsa"
        "encoding/base64"
        "encoding/json"
        "encoding/pem"
        "errors"
        "fmt"
        "io"
        "log"
        "os"
        "path/filepath"
        "regexp"
        "slices"
        "strings"
        "time"
        "github.com/lestrrat-go/jwx/v2/jwa"
        "github.com/openpubkey/openpubkey/client"
        "github.com/openpubkey/openpubkey/client/choosers"
        "github.com/openpubkey/openpubkey/oidc"
        "github.com/openpubkey/openpubkey/pktoken"
        "github.com/openpubkey/openpubkey/providers"
        "github.com/openpubkey/openpubkey/util"
        "github.com/openpubkey/opkssh/commands/config"
        "github.com/openpubkey/opkssh/sshcert"
        "github.com/spf13/afero"
        "github.com/thediveo/enumflag/v2"
        "golang.org/x/crypto/ed25519"
        "golang.org/x/crypto/ssh"
)
// KeyType selects the algorithm of the user's key pair. OpenPubkey uses it as
// the algorithm of the upk (user public key), and SSH uses it for the public
// key embedded in the SSH certificate that opkssh generates.
type KeyType enumflag.Flag

// Supported key types. ECDSA is the zero value.
const (
	ECDSA KeyType = iota // ECDSA key pair
	ED25519              // Ed25519 key pair
)

// String returns the lowercase name of the key type ("ecdsa" or "ed25519"),
// or "unknown" for any value outside the supported set.
func (k KeyType) String() string {
	name := "unknown"
	switch k {
	case ED25519:
		name = "ed25519"
	case ECDSA:
		name = "ecdsa"
	}
	return name
}
// LoginCmd represents the login command that performs OIDC authentication and generates SSH certificates.
//
// The *Arg fields carry user-supplied options. Config is resolved by Run only
// when it is nil. The unexported output fields are filled in by the login
// flow: login stores the PK Token in pkt on the receiver and returns a
// separate LoginCmd holding pkt, signer, alg, client and principals, which
// LoginWithRefresh then uses to refresh the token.
type LoginCmd struct {
	// Inputs
	// Fs is the filesystem used for log, client config, SSH config and key file I/O.
	Fs                    afero.Fs
	AutoRefreshArg        bool   // Automatically refresh PK token after login
	ConfigPathArg         string // Path to the client config file.
	CreateConfigArg       bool   // Creates a client config file if it does not exist
	ConfigureArg          bool   // Apply changes to ssh config and create ~/.ssh/opkssh directory
	LogDirArg             string // Directory to write output logs
	SendAccessTokenArg    bool   // Send the Access Token as well as the PK Token in the SSH cert. The Access Token is used to call the userinfo endpoint to get claims not included in the ID Token
	DisableBrowserOpenArg bool   // Disable opening the browser. Useful for choosing the browser you want to use
	PrintIdTokenArg       bool   // Print out the contents of the id_token. Useful for inspecting claims and troubleshooting
	KeyPathArg            string // Path where SSH private key is written
	ProviderArg           string // OpenID Provider specification in the format: <issuer>,<client_id> or <issuer>,<client_id>,<client_secret> or <issuer>,<client_id>,<client_secret>,<scopes>
	// ProviderAliasArg is the alias of a configured provider to log in with.
	// It takes precedence over the default-provider environment variable and
	// the client config's default provider (see determineProvider).
	ProviderAliasArg      string
	// KeyTypeArg selects the key pair algorithm (ECDSA -> ES256,
	// ED25519 -> EdDSA) and the default key file names under ~/.ssh.
	KeyTypeArg            KeyType
	// SSHConfigured is set when ~/.ssh/config includes ~/.ssh/opkssh/config
	// and that file exists (or configureSSH just set it up). When no explicit
	// key path is given, login then writes keys under ~/.ssh/opkssh.
	SSHConfigured         bool
	Verbosity             int                       // Default verbosity is 0, 1 is verbose, 2 is debug
	overrideProvider      *providers.OpenIdProvider // Used in tests to override the provider to inject a mock provider
	// State
	// Config is the client config. Run loads it from ConfigPathArg (or falls
	// back to the built-in default) only when it is nil, so tests can preset it.
	Config *config.ClientConfig
	// Outputs
	pkt        *pktoken.PKToken        // PK Token obtained by login
	signer     crypto.Signer           // Signer for the generated user key pair
	alg        jwa.SignatureAlgorithm  // Algorithm of the generated user key pair
	client     *client.OpkClient       // OpenPubkey client, reused to refresh the PK Token
	principals []string                // Principals placed in the SSH certificate (empty: server policy decides)
}
// NewLogin builds a LoginCmd from already-parsed command-line arguments. The
// returned command performs all file access through the real OS filesystem.
func NewLogin(autoRefreshArg bool, configPathArg string, createConfigArg bool, configureArg bool, logDirArg string,
	sendAccessTokenArg bool, disableBrowserOpenArg bool, printIdTokenArg bool,
	providerArg string, keyPathArg string, providerAliasArg string, keyTypeArg KeyType,
) *LoginCmd {
	cmd := &LoginCmd{Fs: afero.NewOsFs()}

	// Behavioural flags.
	cmd.AutoRefreshArg = autoRefreshArg
	cmd.CreateConfigArg = createConfigArg
	cmd.ConfigureArg = configureArg
	cmd.SendAccessTokenArg = sendAccessTokenArg
	cmd.DisableBrowserOpenArg = disableBrowserOpenArg
	cmd.PrintIdTokenArg = printIdTokenArg

	// Paths.
	cmd.ConfigPathArg = configPathArg
	cmd.LogDirArg = logDirArg
	cmd.KeyPathArg = keyPathArg

	// Provider and key selection.
	cmd.ProviderArg = providerArg
	cmd.ProviderAliasArg = providerAliasArg
	cmd.KeyTypeArg = keyTypeArg

	return cmd
}
// Run executes the login command. Steps:
//  1. Set up logging, to stdout and optionally a log file.
//  2. Resolve the client config if it was not preset.
//  3. Either configure SSH and return (--configure), or detect an existing
//     SSH configuration.
//  4. Choose an OpenID Provider.
//  5. Perform a one-shot login, or a login that keeps refreshing the PK Token.
func (l *LoginCmd) Run(ctx context.Context) error {
	// If a log directory was provided, write any logs to a file in that directory AND stdout
	if l.LogDirArg != "" {
		logFilePath := filepath.Join(l.LogDirArg, "opkssh.log")
		logFile, err := l.Fs.OpenFile(logFilePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o660)
		if err != nil {
			// Fall back to stdout only: on failure logFile is nil. Deferring
			// Close on it, or writing to it through a MultiWriter, would panic.
			log.SetOutput(os.Stdout)
			log.Printf("Failed to open log for writing: %v \n", err)
		} else {
			defer logFile.Close()
			log.SetOutput(io.MultiWriter(os.Stdout, logFile))
		}
	} else {
		log.SetOutput(os.Stdout)
	}
	if l.Verbosity >= 2 {
		log.Printf("DEBUG: running login command with args: %+v", *l)
	}

	// If the Config has been set in the struct don't replace it. This is useful for testing
	if l.Config == nil {
		if err := config.ResolveClientConfigPath(&l.ConfigPathArg); err != nil {
			return err
		}
		if _, err := l.Fs.Stat(l.ConfigPathArg); err == nil {
			if l.CreateConfigArg {
				log.Printf("--create-config=true but config file already exists at %s", l.ConfigPathArg)
			}
			clientConfig, err := config.GetClientConfigFromFile(l.ConfigPathArg, l.Fs)
			if err != nil {
				return err
			}
			l.Config = clientConfig
		} else {
			if l.CreateConfigArg {
				return config.CreateDefaultClientConfig(l.ConfigPathArg, l.Fs)
			}
			log.Printf("failed to find client config file to generate a default config, run `opkssh login --create-config` to create a default config file")
			defaultConfig, err := config.NewClientConfig(config.DefaultClientConfig)
			if err != nil {
				return fmt.Errorf("failed to parse default config file: %w", err)
			}
			l.Config = defaultConfig
		}
	}

	// --configure only sets up the SSH config; it does not log in.
	if l.ConfigureArg {
		if err := l.configureSSH(); err != nil {
			return fmt.Errorf("failed to configure SSH: %w", err)
		}
		return nil
	}
	l.checkSSHConfigured()

	if isGitHubEnvironment() {
		l.Config.Providers = append(l.Config.Providers, config.GitHubProviderConfig())
	}

	var provider providers.OpenIdProvider
	if l.overrideProvider != nil {
		provider = *l.overrideProvider
	} else {
		op, chooser, err := l.determineProvider()
		if err != nil {
			return err
		}
		switch {
		case chooser != nil:
			provider, err = chooser.ChooseOp(ctx)
			if err != nil {
				return fmt.Errorf("error choosing provider: %w", err)
			}
		case op != nil:
			provider = op
		default:
			// Either the provider or the chooser must be set. If this occurs we have a bug in the code.
			return fmt.Errorf("no provider found")
		}
	}

	// This arg is true if set, so if it false it hasn't been set and
	// we should use the config value for the matching providing.
	// If it is true we ignore the config
	if !l.SendAccessTokenArg {
		if opConfig, ok := l.Config.GetByIssuer(provider.Issuer()); !ok {
			// This can happen if the provider is supplied via the command line or environment variables and thus not in the config
			log.Printf("Warning: could not find issuer %s in client config providers\n", provider.Issuer())
		} else {
			l.SendAccessTokenArg = opConfig.SendAccessToken
		}
	}

	// Execute login command
	if l.AutoRefreshArg {
		providerRefreshable, ok := provider.(providers.RefreshableOpenIdProvider)
		if !ok {
			return fmt.Errorf("supplied OpenID Provider (%v) does not support auto-refresh and auto-refresh argument set to true", provider.Issuer())
		}
		if err := l.LoginWithRefresh(ctx, providerRefreshable, l.PrintIdTokenArg, l.KeyPathArg); err != nil {
			return fmt.Errorf("error logging in: %w", err)
		}
		return nil
	}
	if err := l.Login(ctx, provider, l.PrintIdTokenArg, l.KeyPathArg); err != nil {
		return fmt.Errorf("error logging in: %w", err)
	}
	return nil
}
// configureSSH prepares the SSH client to pick up opkssh-generated keys. Steps:
//  1. Create ~/.ssh/opkssh (0700).
//  2. Create ~/.ssh/opkssh/config (0600) if it does not exist; existing
//     content is kept.
//  3. Prepend "Include ~/.ssh/opkssh/config" to ~/.ssh/config unless the
//     directive is already present.
//
// On success l.SSHConfigured is set to true.
func (l *LoginCmd) configureSSH() error {
	userhomeDir, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("failed to get user home dir: %w", err)
	}
	const includeDirective = "Include ~/.ssh/opkssh/config"
	const opkSshDir = ".ssh/opkssh"
	var userSshConfig = filepath.Join(userhomeDir, ".ssh/config")
	var userOpkSshDir = filepath.Join(userhomeDir, opkSshDir)
	var userOpkSshConfig = filepath.Join(userOpkSshDir, "config")
	if _, err := l.Fs.Stat(userOpkSshConfig); err == nil {
		log.Println("--configure but already configured")
	}

	log.Printf("Creating config directory at %s", userOpkSshDir)
	afs := &afero.Afero{Fs: l.Fs}
	if err := afs.MkdirAll(userOpkSshDir, 0o0700); err != nil {
		return fmt.Errorf("failed to create opkssh SSH directory: %w", err)
	}

	log.Printf("Creating config file at %s", userOpkSshConfig)
	// Open only to ensure the file exists. O_RDONLY (the value 0) is spelled
	// out so it is clear that no truncation or write access is intended.
	file, err := afs.OpenFile(userOpkSshConfig, os.O_RDONLY|os.O_CREATE, 0o0600)
	if err != nil {
		return fmt.Errorf("failed to create opkssh SSH config file: %w", err)
	}
	defer file.Close()

	log.Printf("Adding include directive to SSH config at %s", "~/.ssh/config")
	content, err := afs.ReadFile(userSshConfig)
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("failed to read SSH config file: %w", err)
	}
	if strings.Contains(string(content), includeDirective) {
		log.Println("Found include directive file in SSH config, skipping...")
	} else {
		// Prepend the Include so it takes effect before any Host blocks.
		content = slices.Concat([]byte(includeDirective+"\n\n"), content)
		if err := afs.WriteFile(userSshConfig, content, 0o0600); err != nil {
			return fmt.Errorf("failed to write SSH config file: %w", err)
		}
	}
	l.SSHConfigured = true
	log.Println("Configured SSH identity directory")
	return nil
}
// checkSSHConfigured sets l.SSHConfigured when ~/.ssh/config contains the
// opkssh Include directive AND ~/.ssh/opkssh/config exists. Any failure
// (missing home dir, missing or unreadable files) leaves the flag unchanged.
func (l *LoginCmd) checkSSHConfigured() {
	home, err := os.UserHomeDir()
	if err != nil {
		log.Printf("Failed to get user config dir: %v", err)
		return
	}
	const (
		includeDirective = "Include ~/.ssh/opkssh/config"
		opkSshDir        = ".ssh/opkssh"
	)
	sshConfigPath := filepath.Join(home, ".ssh/config")
	opkConfigPath := filepath.Join(home, opkSshDir, "config")

	fsys := &afero.Afero{Fs: l.Fs}
	// Without a readable user SSH config carrying our Include, ssh cannot
	// be using the opkssh config.
	sshConfig, err := fsys.ReadFile(sshConfigPath)
	if err != nil || !strings.Contains(string(sshConfig), includeDirective) {
		return
	}
	// The included opkssh config itself must also exist.
	if _, err := fsys.Stat(opkConfigPath); err != nil {
		return
	}
	fmt.Println("OPK SSH identity directory is configured")
	l.SSHConfigured = true
}
// determineProvider decides which OpenID Provider to log in with. Exactly one
// of the returned provider or chooser is non-nil on success.
//
// Precedence:
//  1. An explicit --provider specification (ProviderArg) is used directly.
//  2. Otherwise the alias comes from ProviderAliasArg, then the default-provider
//     environment variable, then the config's DefaultProvider, and finally the
//     web chooser.
//  3. Provider configs come from the environment if set, otherwise from the
//     client config.
//
// For the web chooser, only providers that support browser-based login are
// offered. Others are skipped with a log message, and an error is returned if
// none remain.
func (l *LoginCmd) determineProvider() (providers.OpenIdProvider, *choosers.WebChooser, error) {
	openBrowser := !l.DisableBrowserOpenArg

	// If the user has supplied commandline arguments for the provider, short circuit and use providerArg
	if l.ProviderArg != "" {
		providerConfig, err := config.NewProviderConfigFromString(l.ProviderArg, false)
		if err != nil {
			return nil, nil, fmt.Errorf("error parsing provider argument: %w", err)
		}
		provider, err := providerConfig.ToProvider(openBrowser)
		if err != nil {
			return nil, nil, fmt.Errorf("error creating provider from config: %w", err)
		}
		return provider, nil, nil
	}

	// Set the default provider from the env variable if specified
	defaultProviderEnv, _ := os.LookupEnv(config.OPKSSH_DEFAULT_ENVVAR)
	providerConfigsEnv, err := config.GetProvidersConfigFromEnv()
	if err != nil {
		return nil, nil, fmt.Errorf("error getting provider config from env: %w", err)
	}

	var defaultProviderAlias string
	switch {
	case l.ProviderAliasArg != "":
		defaultProviderAlias = l.ProviderAliasArg
	case defaultProviderEnv != "":
		defaultProviderAlias = defaultProviderEnv
	case l.Config.DefaultProvider != "":
		defaultProviderAlias = l.Config.DefaultProvider
	default:
		defaultProviderAlias = config.WEBCHOOSER_ALIAS
	}

	var providerConfigs []config.ProviderConfig
	switch {
	case providerConfigsEnv != nil:
		providerConfigs = providerConfigsEnv
	case len(l.Config.Providers) > 0:
		providerConfigs = l.Config.Providers
	default:
		return nil, nil, fmt.Errorf("no providers specified")
	}

	if strings.ToUpper(defaultProviderAlias) != config.WEBCHOOSER_ALIAS {
		providerMap, err := config.CreateProvidersMap(providerConfigs)
		if err != nil {
			return nil, nil, fmt.Errorf("error creating provider map: %w", err)
		}
		providerConfig, ok := providerMap[defaultProviderAlias]
		if !ok {
			return nil, nil, fmt.Errorf("error getting provider config for alias %s", defaultProviderAlias)
		}
		provider, err := providerConfig.ToProvider(openBrowser)
		if err != nil {
			return nil, nil, fmt.Errorf("error creating provider from config: %w", err)
		}
		return provider, nil, nil
	}

	// If the default provider is WEBCHOOSER, we need to create a chooser and return it
	var providerList []providers.BrowserOpenIdProvider
	for _, providerConfig := range providerConfigs {
		op, err := providerConfig.ToProvider(openBrowser)
		if err != nil {
			return nil, nil, fmt.Errorf("error creating provider from config: %w", err)
		}
		// Not every provider can be driven through a browser (e.g. the GitHub
		// Actions provider appended in CI). An unchecked assertion would panic.
		browserOp, ok := op.(providers.BrowserOpenIdProvider)
		if !ok {
			log.Printf("Skipping provider %s in web chooser: it does not support browser-based login", op.Issuer())
			continue
		}
		providerList = append(providerList, browserOp)
	}
	if len(providerList) == 0 {
		return nil, nil, fmt.Errorf("no browser-based providers available for the web chooser")
	}
	chooser := choosers.NewWebChooser(
		providerList, openBrowser,
	)
	return nil, chooser, nil
}
// login runs one OIDC authentication and produces the SSH key material.
//
// Steps:
//  1. Generate a fresh key pair for KeyTypeArg.
//  2. Obtain a PK Token from provider and store it in l.pkt.
//  3. Optionally fetch the access token.
//  4. Build the SSH certificate.
//
// The keys are written to seckeyPath if non-empty, else to ~/.ssh/opkssh when
// SSH is configured, else to a default ~/.ssh key path.
//
// It returns a LoginCmd carrying pkt, signer, client, alg and principals so a
// caller can refresh the token later.
func (l *LoginCmd) login(ctx context.Context, provider providers.OpenIdProvider, printIdToken bool, seckeyPath string) (*LoginCmd, error) {
	var alg jwa.SignatureAlgorithm
	switch l.KeyTypeArg {
	case ECDSA:
		alg = jwa.ES256
	case ED25519:
		alg = jwa.EdDSA
	default:
		return nil, fmt.Errorf("unsupported key type (%s); use -t <%s|%s>", l.KeyTypeArg.String(), ECDSA.String(), ED25519.String())
	}

	keySigner, err := util.GenKeyPair(alg)
	if err != nil {
		return nil, fmt.Errorf("failed to generate keypair: %w", err)
	}
	opk, err := client.New(provider, client.WithSigner(keySigner, alg))
	if err != nil {
		return nil, err
	}
	token, err := opk.Auth(ctx)
	if err != nil {
		return nil, err
	}
	l.pkt = token

	var accessToken []byte
	if l.SendAccessTokenArg {
		if accessToken = opk.GetAccessToken(); accessToken == nil {
			return nil, fmt.Errorf("access token required but provider (%s) did not set access-token", opk.Op.Issuer())
		}
	}

	// If principals is empty the server does not enforce any principal. The OPK
	// verifier should use policy to make this decision.
	principals := []string{}
	cert, secKeyPem, err := createSSHCertWithAccessToken(token, accessToken, keySigner, principals)
	if err != nil {
		return nil, fmt.Errorf("failed to generate SSH cert: %w", err)
	}

	// Write ssh secret key and public key to filesystem
	switch {
	case seckeyPath != "":
		if err := l.writeKeys(seckeyPath, seckeyPath+"-cert.pub", secKeyPem, cert); err != nil {
			return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
		}
	case l.SSHConfigured:
		if err := l.writeKeysToOpkSSHDir(secKeyPem, cert); err != nil {
			return nil, fmt.Errorf("failed to write SSH keys to OPK SSH dir: %w", err)
		}
	default:
		if err := l.writeKeysToSSHDir(secKeyPem, cert); err != nil {
			return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
		}
	}

	if printIdToken {
		pretty, err := PrettyIdToken(*token)
		if err != nil {
			return nil, fmt.Errorf("failed to format ID Token: %w", err)
		}
		fmt.Printf("id_token:\n%s\n", pretty)
	}

	identity, err := IdentityString(*token)
	if err != nil {
		return nil, fmt.Errorf("failed to parse ID Token: %w", err)
	}
	fmt.Printf("Keys generated for identity\n%s\n", identity)

	return &LoginCmd{
		pkt:        token,
		signer:     keySigner,
		client:     opk,
		alg:        alg,
		principals: principals,
	}, nil
}
// Login performs a single OIDC login and writes the resulting SSH key and
// certificate to disk. seckeyPath overrides the output location when non-empty.
// printIdToken additionally prints the ID token claims.
func (l *LoginCmd) Login(ctx context.Context, provider providers.OpenIdProvider, printIdToken bool, seckeyPath string) error {
	if _, err := l.login(ctx, provider, printIdToken, seckeyPath); err != nil {
		return err
	}
	return nil
}
// LoginWithRefresh performs the OIDC login procedure and writes the SSH
// certs/keys, then keeps running. About one minute before the ID token's
// "exp" it refreshes the PK Token and writes a new SSH cert/key to the same
// location the initial login used:
//   - seckeyPath, if non-empty;
//   - otherwise ~/.ssh/opkssh, if SSH is configured;
//   - otherwise the default ~/.ssh key path.
//
// It only returns on error or when ctx is cancelled (returning ctx.Err()).
func (l *LoginCmd) LoginWithRefresh(ctx context.Context, provider providers.RefreshableOpenIdProvider, printIdToken bool, seckeyPath string) error {
	loginResult, err := l.login(ctx, provider, printIdToken, seckeyPath)
	if err != nil {
		return err
	}
	var claims struct {
		Expiration int64 `json:"exp"`
	}
	if err := json.Unmarshal(loginResult.pkt.Payload, &claims); err != nil {
		return err
	}
	for {
		// Sleep until a minute before expiration to give us time to refresh
		// the token and minimize any interruptions
		untilExpired := time.Until(time.Unix(claims.Expiration, 0)) - time.Minute
		log.Printf("Waiting for %v before attempting to refresh id_token...", untilExpired)
		select {
		case <-time.After(untilExpired):
			log.Print("Refreshing id_token...")
		case <-ctx.Done():
			return ctx.Err()
		}

		refreshedPkt, err := loginResult.client.Refresh(ctx)
		if err != nil {
			return err
		}
		loginResult.pkt = refreshedPkt
		// writeKeysToOpkSSHDir derives the key file name and comment from l.pkt.
		l.pkt = refreshedPkt

		var accessToken []byte
		if l.SendAccessTokenArg {
			accessToken = loginResult.client.GetAccessToken()
			if accessToken == nil {
				// err is nil here (Refresh succeeded), so it must not be wrapped.
				return fmt.Errorf("access token required but provider (%s) did not set access-token on refresh", loginResult.client.Op.Issuer())
			}
		}
		certBytes, seckeySshPem, err := createSSHCertWithAccessToken(loginResult.pkt, accessToken, loginResult.signer, loginResult.principals)
		if err != nil {
			return fmt.Errorf("failed to generate SSH cert: %w", err)
		}

		// Write to the same location the initial login() used.
		switch {
		case seckeyPath != "":
			if err := l.writeKeys(seckeyPath, seckeyPath+"-cert.pub", seckeySshPem, certBytes); err != nil {
				return fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
			}
		case l.SSHConfigured:
			if err := l.writeKeysToOpkSSHDir(seckeySshPem, certBytes); err != nil {
				return fmt.Errorf("failed to write SSH keys to OPK SSH dir: %w", err)
			}
		default:
			if err := l.writeKeysToSSHDir(seckeySshPem, certBytes); err != nil {
				return fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
			}
		}

		// Read the new expiration from the refreshed token's payload.
		comPkt, err := refreshedPkt.Compact()
		if err != nil {
			return err
		}
		payloadB64 := payloadFromCompactPkt(comPkt)
		payload, err := base64.RawURLEncoding.DecodeString(string(payloadB64))
		if err != nil {
			return fmt.Errorf("refreshed ID token payload is not base64 encoded: %w", err)
		}
		if err = json.Unmarshal(payload, &claims); err != nil {
			return fmt.Errorf("malformed refreshed ID token payload: %w", err)
		}
	}
}
func createSSHCert(pkt *pktoken.PKToken, signer crypto.Signer, principals []string) ([]byte, []byte, error) {
        return createSSHCertWithAccessToken(pkt, nil, signer, principals)
}
// createSSHCertWithAccessToken builds an SSH certificate that embeds pkt (and
// accessToken, if non-nil) and signs it with signer, restricted to the single
// SSH algorithm matching the key type (ECDSA P-256 or Ed25519).
//
// It returns the certificate in authorized-key format (without a trailing
// newline) and the PEM-encoded OpenSSH private key. Unsupported signer types
// yield an error.
func createSSHCertWithAccessToken(pkt *pktoken.PKToken, accessToken []byte, signer crypto.Signer, principals []string) ([]byte, []byte, error) {
	cert, err := sshcert.New(pkt, accessToken, principals)
	if err != nil {
		return nil, nil, err
	}
	sshSigner, err := ssh.NewSignerFromSigner(signer)
	if err != nil {
		return nil, nil, err
	}
	var keyAlgos []string
	switch signer.(type) {
	case *ecdsa.PrivateKey:
		keyAlgos = []string{ssh.KeyAlgoECDSA256}
	case ed25519.PrivateKey:
		keyAlgos = []string{ssh.KeyAlgoED25519}
	default:
		return nil, nil, fmt.Errorf("unsupported key type: %T", signer)
	}
	// Checked assertion: report an error rather than panic if the ssh
	// package ever returns a signer without algorithm selection.
	algSigner, ok := sshSigner.(ssh.AlgorithmSigner)
	if !ok {
		return nil, nil, fmt.Errorf("ssh signer for key type %T does not support algorithm selection", signer)
	}
	signerMas, err := ssh.NewSignerWithAlgorithms(algSigner, keyAlgos)
	if err != nil {
		return nil, nil, err
	}
	sshCert, err := cert.SignCert(signerMas)
	if err != nil {
		return nil, nil, err
	}
	// Remove the newline character that MarshalAuthorizedKey() adds
	certBytes := bytes.TrimSuffix(ssh.MarshalAuthorizedKey(sshCert), []byte("\n"))
	seckeySsh, err := ssh.MarshalPrivateKey(signer, "openpubkey cert")
	if err != nil {
		return nil, nil, err
	}
	seckeySshBytes := pem.EncodeToMemory(seckeySsh)
	return certBytes, seckeySshBytes, nil
}
// writeKeysToOpkSSHDir writes the private key and certificate into
// ~/.ssh/opkssh, using a file name derived from l.pkt's issuer and audience.
// It prepends an "IdentityFile <path>" line to ~/.ssh/opkssh/config unless
// that exact line is already present. The certificate's comment is
// "openpubkey: <issuer> <audience>", with "unknown" for claims that cannot be
// read.
//
// ~/.ssh/opkssh/config must already exist (see configureSSH); otherwise an
// error is returned.
func (l *LoginCmd) writeKeysToOpkSSHDir(secKeyPem []byte, certBytes []byte) error {
	const (
		opkSshPath     = ".ssh/opkssh"
		configFileName = "config"
	)
	userhomeDir, err := os.UserHomeDir()
	if err != nil {
		return err
	}
	opkSshUserPath := filepath.Join(userhomeDir, opkSshPath)
	opkSshConfigPath := filepath.Join(opkSshUserPath, configFileName)
	privKeyPath := filepath.Join(opkSshUserPath, l.makeSSHKeyFileName(l.pkt))
	pubKeyPath := privKeyPath + "-cert.pub"

	// get key comment
	issuer, err := l.pkt.Issuer()
	if err != nil {
		issuer = "unknown"
	}
	audience, err := l.pkt.Audience()
	if err != nil {
		audience = "unknown"
	}
	// No leading space: writeKeysComment inserts the separator itself.
	comment := "openpubkey: " + issuer + " " + audience

	// add key to config
	afs := &afero.Afero{Fs: l.Fs}
	configContent, err := afs.ReadFile(opkSshConfigPath)
	if err != nil {
		return fmt.Errorf("failed to read opk ssh config file (%s): %w", opkSshConfigPath, err)
	}
	// Match whole lines. A plain substring test would treat this path as
	// present when it is merely a prefix of another entry's path.
	identityLine := "IdentityFile " + privKeyPath
	alreadyListed := slices.ContainsFunc(strings.Split(string(configContent), "\n"), func(line string) bool {
		return strings.TrimSpace(line) == identityLine
	})
	if !alreadyListed {
		configContent = slices.Concat([]byte(identityLine+"\n"), configContent)
	}
	if err := afs.WriteFile(opkSshConfigPath, configContent, 0600); err != nil {
		return fmt.Errorf("failed to write opk ssh config file (%s): %w", opkSshConfigPath, err)
	}

	// write ssh key files
	return l.writeKeysComment(privKeyPath, pubKeyPath, secKeyPem, certBytes, comment)
}
// writeKeysToSSHDir writes the key pair to one of the default OpenSSH key
// paths in ~/.ssh so ssh finds it without extra configuration. Candidates are
// id_ecdsa/id_ecdsa_sk or id_ed25519/id_ed25519_sk, depending on KeyTypeArg.
//
// A candidate is used when its private key file does not exist yet, or when
// its "-cert.pub" file carries the comment "openpubkey" (i.e. opkssh wrote it).
// Candidates whose private key exists without a cert file, or whose cert file
// cannot be read or parsed, are skipped. An error is returned if no candidate
// is free.
func (l *LoginCmd) writeKeysToSSHDir(seckeySshPem []byte, certBytes []byte) error {
	homePath, err := os.UserHomeDir()
	if err != nil {
		return err
	}
	sshPath := filepath.Join(homePath, ".ssh")
	// Make ~/.ssh if folder does not exist. Use 0700, the permissions OpenSSH
	// expects, rather than os.ModePerm, which leaves the directory group- and
	// world-accessible subject only to the umask.
	if err := l.Fs.MkdirAll(sshPath, 0o700); err != nil {
		return err
	}
	// For ssh to automatically find the key created by openpubkey when
	// connecting, we use one of the default ssh key paths. However, the file
	// might contain an existing key. We will overwrite the key if it was
	// generated by openpubkey  which we check by looking at the associated
	// comment. If the comment is equal to "openpubkey", we overwrite the file
	// with a new key.
	var keyFileNames []string
	switch l.KeyTypeArg {
	case ECDSA:
		keyFileNames = []string{"id_ecdsa", "id_ecdsa_sk"}
	case ED25519:
		keyFileNames = []string{"id_ed25519", "id_ed25519_sk"}
	default:
		return fmt.Errorf("key type (%s) has no default output file name; use -i <filePath>", l.KeyTypeArg.String())
	}
	afs := &afero.Afero{Fs: l.Fs}
	for _, keyFilename := range keyFileNames {
		seckeyPath := filepath.Join(sshPath, keyFilename)
		pubkeyPath := seckeyPath + "-cert.pub"
		if !l.fileExists(seckeyPath) {
			// If ssh key file does not currently exist, we don't have to worry about overwriting it
			return l.writeKeys(seckeyPath, pubkeyPath, seckeySshPem, certBytes)
		}
		if !l.fileExists(pubkeyPath) {
			continue
		}
		// If the ssh key file does exist, check if it was generated by openpubkey, if it was then it is safe to overwrite
		sshPubkey, err := afs.ReadFile(pubkeyPath)
		if err != nil {
			log.Println("Failed to read:", pubkeyPath)
			continue
		}
		_, comment, _, _, err := ssh.ParseAuthorizedKey(sshPubkey)
		if err != nil {
			log.Println("Failed to parse:", pubkeyPath)
			continue
		}
		// If the key comment is "openpubkey" then we generated it
		if comment == "openpubkey" {
			return l.writeKeys(seckeyPath, pubkeyPath, seckeySshPem, certBytes)
		}
	}
	return fmt.Errorf("no default ssh key file free for openpubkey")
}
// writeKeys writes the private key (0600) to seckeyPath and the certificate
// (0644) to pubkeyPath. The certificate gets the comment "openpubkey", which
// writeKeysToSSHDir later uses to recognise files it may overwrite.
func (l *LoginCmd) writeKeys(seckeyPath string, pubkeyPath string, seckeySshPem []byte, certBytes []byte) error {
	const opkComment = "openpubkey"
	return l.writeKeysComment(seckeyPath, pubkeyPath, seckeySshPem, certBytes, opkComment)
}
// writeKeysComment writes the private key to seckeyPath with mode 0600. It
// then writes the line "<certBytes> <pubKeyComment>" to pubkeyPath with mode
// 0644. If the private key write fails, the certificate is not written.
func (l *LoginCmd) writeKeysComment(seckeyPath string, pubkeyPath string, seckeySshPem []byte, certBytes []byte, pubKeyComment string) error {
	fsys := &afero.Afero{Fs: l.Fs}
	if err := fsys.WriteFile(seckeyPath, seckeySshPem, 0o600); err != nil {
		return err
	}
	fmt.Printf("Writing opk ssh public key to %s and corresponding secret key to %s\n", pubkeyPath, seckeyPath)
	// Authorized-key line format: certificate, single space, comment.
	pubKeyLine := slices.Concat(certBytes, []byte(" "+pubKeyComment))
	return fsys.WriteFile(pubkeyPath, pubKeyLine, 0o644)
}
// sshKeyFileNameUnsafeChars matches runs of characters not allowed in a
// generated SSH key file name. It is compiled once at package init rather
// than on every call.
var sshKeyFileNameUnsafeChars = regexp.MustCompile(`[^a-zA-Z0-9_\-.]+`)

// makeSSHKeyFileName derives a file-system-safe key file name of the form
// "<issuer without https://>-<first 20 bytes of audience>". Each run of
// characters outside [a-zA-Z0-9_.-] is replaced with "_". Claims that cannot
// be read become "unknown".
func (l *LoginCmd) makeSSHKeyFileName(pkt *pktoken.PKToken) string {
	issuer, err := pkt.Issuer()
	if err != nil {
		issuer = "unknown"
	}
	issuer = strings.TrimPrefix(issuer, "https://")
	audience, err := pkt.Audience()
	if err != nil {
		audience = "unknown"
	}
	// shorten clientID if it is too long
	const maxAudienceLen = 20
	if len(audience) > maxAudienceLen {
		audience = audience[:maxAudienceLen]
	}
	return sshKeyFileNameUnsafeChars.ReplaceAllString(issuer+"-"+audience, "_")
}
// fileExists reports whether fPath exists on l.Fs. Only a "does not exist"
// error counts as absent; other errors (e.g. permission denied) are treated
// as present so callers never overwrite a file they could not inspect.
// It uses Stat rather than Open so no file handle is left open.
func (l *LoginCmd) fileExists(fPath string) bool {
	_, err := l.Fs.Stat(fPath)
	return !errors.Is(err, os.ErrNotExist)
}
// IdentityString renders the identity in the PK Token's ID token for display.
// When an email claim is present, the result lists email, sub, issuer and
// audience. Otherwise it starts with a warning that email-based policies will
// not work, then lists sub, issuer and audience. An error is returned if the
// ID token cannot be parsed.
func IdentityString(pkt pktoken.PKToken) (string, error) {
	idt, err := oidc.NewJwt(pkt.OpToken)
	if err != nil {
		return "", err
	}
	c := idt.GetClaims()
	if c.Email != "" {
		return fmt.Sprintf("Email, sub, issuer, audience: \n%s %s %s %s", c.Email, c.Subject, c.Issuer, c.Audience), nil
	}
	return fmt.Sprintf("WARNING: Email claim is missing from ID token. Policies based on email will not work.\n"+
		"Check if your client config (~/.opk/config.yml) has the correct scopes configured for this OpenID Provider.\n"+
		"Sub, issuer, audience:\n"+
		"%s %s %s", c.Subject, c.Issuer, c.Audience), nil
}
// PrettyIdToken parses the PK Token's ID token and returns its claims as
// JSON indented with four spaces per level. Errors from parsing the token or
// from JSON encoding are returned unchanged.
func PrettyIdToken(pkt pktoken.PKToken) (string, error) {
        idToken, err := oidc.NewJwt(pkt.OpToken)
        if err != nil {
                return "", err
        }
        pretty, err := json.MarshalIndent(idToken.GetClaims(), "", "    ")
        if err != nil {
                return "", err
        }
        return string(pretty), nil
}
// isGitHubEnvironment reports whether both GitHub Actions OIDC variables,
// ACTIONS_ID_TOKEN_REQUEST_URL and ACTIONS_ID_TOKEN_REQUEST_TOKEN, are set to
// non-empty values. The token variable is only consulted if the URL is set.
func isGitHubEnvironment() bool {
        for _, name := range []string{"ACTIONS_ID_TOKEN_REQUEST_URL", "ACTIONS_ID_TOKEN_REQUEST_TOKEN"} {
                if os.Getenv(name) == "" {
                        return false
                }
        }
        return true
}
// payloadFromCompactPkt extracts the payload from a compact PK Token, which
// is always the second part of the '.' separated string.
//
// If compactPkt contains no '.', there is no second part and nil is returned
// (previously this panicked with an index out of range). The returned slice
// aliases compactPkt but has its capacity capped at its length, so appending
// to it cannot overwrite the remainder of compactPkt.
func payloadFromCompactPkt(compactPkt []byte) []byte {
        _, rest, found := bytes.Cut(compactPkt, []byte("."))
        if !found {
                return nil
        }
        // The payload ends at the next '.', or at the end of input if there is none.
        payload, _, _ := bytes.Cut(rest, []byte("."))
        return payload[:len(payload):len(payload)]
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
//go:build linux || darwin
package commands
import (
        "errors"
        "fmt"
        "io"
        "os"
        "os/user"
        "path/filepath"
        "regexp"
        "strconv"
        "syscall"

        "github.com/openpubkey/opkssh/policy/files"
)
// ReadHome reads and returns the home policy file (`~/.opk/auth_id`) of the
// user with the specified username. It is used when opkssh is called by
// AuthorizedKeysCommand as the opksshuser, which relies on sudoer access to
// read the home policy file (`/home/<username>/.opk/auth_id`).
//
// Because it runs with elevated read access, the contents are returned only
// if all of the following hold:
//   - username matches ^[a-z0-9_\-.]+$ and resolves to a local user,
//   - the policy file's final path component is not a symlink (O_NOFOLLOW),
//   - the file's owner UID and resolved username match that user,
//   - the file's permission bits equal files.ModeHomePerms.
//
// The contents are read from the same descriptor that was validated, so the
// file cannot be swapped (e.g. for a symlink) between the checks and the read.
//
// This function is only available on Linux and Darwin because it relies on
// syscall.Stat_t to determine the owner of the file.
func ReadHome(username string) ([]byte, error) {
        if matched, _ := regexp.MatchString("^[a-z0-9_\\-.]+$", username); !matched {
                return nil, fmt.Errorf("%s is not a valid linux username", username)
        }
        userObj, err := user.Lookup(username)
        if err != nil {
                return nil, fmt.Errorf("failed to find user %s: %w", username, err)
        }
        homePolicyPath := filepath.Join(userObj.HomeDir, ".opk", "auth_id")

        // Security critical: we are reading this file as `sudo -u opksshuser`,
        // and opksshuser has elevated permissions to read any file whose path
        // matches `/home/*/.opk/auth_id`. We must NOT follow a symlink, as it
        // could point to a file the user is not permitted to read. That would
        // not let the user read the file, but they might be able to determine
        // whether it exists. O_NOFOLLOW makes the open fail with ELOOP if the
        // final path component is a symlink.
        file, err := os.OpenFile(homePolicyPath, os.O_RDONLY|syscall.O_NOFOLLOW, 0)
        if err != nil {
                if errors.Is(err, syscall.ELOOP) {
                        return nil, fmt.Errorf("home policy file %s is a symlink, symlink are unsafe in this context", homePolicyPath)
                }
                return nil, fmt.Errorf("failed to open %s, %v", homePolicyPath, err)
        }
        defer file.Close()

        fileInfo, err := file.Stat()
        if err != nil {
                return nil, fmt.Errorf("failed to get info on file %s: %w", homePolicyPath, err)
        }
        stat, ok := fileInfo.Sys().(*syscall.Stat_t) // syscall.Stat_t is not available on Windows
        if !ok {
                return nil, fmt.Errorf("failed to stat file %s", homePolicyPath)
        }

        // Ensure the file is owned by the correct user and has the correct permissions.
        requiredOwnerUid := userObj.Uid
        fileOwnerUID := strconv.FormatUint(uint64(stat.Uid), 10)
        fileOwner, err := user.LookupId(fileOwnerUID)
        if err != nil {
                return nil, fmt.Errorf("failed to find username for UID %s for file %s", fileOwnerUID, homePolicyPath)
        }
        if fileOwnerUID != requiredOwnerUid || fileOwner.Username != username {
                return nil, fmt.Errorf("unsafe file permissions on %s expected file owner %s (UID %s) got %s (UID %s)",
                        homePolicyPath, username, requiredOwnerUid, fileOwner.Username, fileOwnerUID)
        }
        if perm := fileInfo.Mode().Perm(); perm != files.ModeHomePerms {
                return nil, fmt.Errorf("unsafe file permissions for %s got %o expected %o", homePolicyPath, perm, files.ModeHomePerms)
        }

        // Read through the descriptor validated above. Re-opening by path (as
        // os.ReadFile does) would follow symlinks and open a time-of-check /
        // time-of-use window in which the user could replace auth_id after the
        // ownership and permission checks have passed.
        fileBytes, err := io.ReadAll(file)
        if err != nil {
                return nil, fmt.Errorf("failed to read %s, %v", homePolicyPath, err)
        }
        return fileBytes, nil
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package commands
import (
        "context"
        "fmt"
        "io/fs"
        "net/http"
        "github.com/openpubkey/openpubkey/pktoken"
        "github.com/openpubkey/openpubkey/verifier"
        "github.com/openpubkey/opkssh/commands/config"
        "github.com/openpubkey/opkssh/policy"
        "github.com/openpubkey/opkssh/policy/files"
        "github.com/openpubkey/opkssh/sshcert"
        "github.com/spf13/afero"
        "golang.org/x/crypto/ssh"
)
// PolicyEnforcerFunc returns nil if the supplied PK token is permitted to login as
// username. Otherwise, an error is returned indicating the reason for rejection.
//
// Besides the verified PK token, VerifyCmd.AuthorizedKeysCommand passes:
//   - userInfo: the userinfo response, or "" when none was retrieved,
//   - sshCert and keyType: the base64-encoded SSH certificate and key type
//     exactly as supplied by sshd,
//   - denyList: the deny list loaded from the server config (empty if the
//     config was not read).
type PolicyEnforcerFunc func(username string, pkt *pktoken.PKToken, userInfo string, sshCert string, keyType string, denyList policy.DenyList) error
// VerifyCmd provides functionality to verify OPK tokens contained in SSH
// certificates and authorize requests to SSH as a specific username using a
// configurable authorization system. It is designed to be used in conjunction
// with sshd's AuthorizedKeysCommand feature.
type VerifyCmd struct {
        // Fs is the filesystem the server config file (ConfigPathArg) is read from.
        Fs afero.Fs
        // PktVerifier is responsible for verifying the PK token
        // contained in the SSH certificate
        PktVerifier verifier.Verifier
        // CheckPolicy determines whether the verified PK token is permitted to SSH as a
        // specific user
        CheckPolicy PolicyEnforcerFunc
        // ConfigPathArg is the path to the server config file
        ConfigPathArg string
        // filePermChecker is used to check the file permissions of the config file
        filePermChecker files.PermsChecker
        // HttpClient is the client handed to the userinfo requester in
        // UserInfoLookup; it can be mocked using a roundtripper in tests.
        // NewVerifyCmd leaves it nil.
        HttpClient *http.Client
        // denyList is populated from ServerConfig after successful parsing
        // and is passed to CheckPolicy on every verification.
        denyList policy.DenyList
}
// NewVerifyCmd builds a VerifyCmd that reads from the real OS filesystem.
// The same filesystem instance backs both config reading (Fs) and the config
// permission checker, which runs external commands via files.ExecCmd.
// HttpClient and the deny list are left at their zero values.
func NewVerifyCmd(pktVerifier verifier.Verifier, checkPolicy PolicyEnforcerFunc, configPathArg string) *VerifyCmd {
        osFs := afero.NewOsFs()
        cmd := &VerifyCmd{
                Fs:            osFs,
                PktVerifier:   pktVerifier,
                CheckPolicy:   checkPolicy,
                ConfigPathArg: configPathArg,
        }
        cmd.filePermChecker = files.PermsChecker{Fs: osFs, CmdRunner: files.ExecCmd}
        return cmd
}
// AuthorizedKeysCommand implements the work behind `opkssh verify`, which the
// SSH server calls as its AuthorizedKeysCommand.
//
// By default, the following lines are added to the sshd_config at /etc/ssh/sshd_config.d/60-opk-ssh.conf:
//
//        AuthorizedKeysCommand /usr/local/bin/opkssh verify %u %k %t
//        AuthorizedKeysCommandUser opksshuser
//
// The sshd tokens map onto this method's parameters (the "Arg" suffix marks
// values supplied by sshd):
//
//        %u The username (requested principal) - userArg
//        %k The base64-encoded public key for authentication - certB64Arg - the public key is also a certificate
//        %t The public key type - typArg - used to parse certB64Arg as one of the SSH certificate types
//
// Processing steps:
//  1. Parse certB64Arg as an SSH certificate of type typArg.
//  2. Verify the PK token embedded in the certificate with v.PktVerifier.
//  3. If the certificate carries an access token, look up userinfo with it.
//     This is best effort: a failed lookup leaves userInfo empty.
//  4. Ask v.CheckPolicy whether the identity may assume userArg, passing the
//     configured deny list.
//
// On success it returns the authorized_keys line sshd expects on standard
// output: "cert-authority " followed by the certificate's signature key.
// Otherwise a non-nil error is returned.
func (v *VerifyCmd) AuthorizedKeysCommand(ctx context.Context, userArg string, typArg string, certB64Arg string) (string, error) {
        cert, err := sshcert.NewFromAuthorizedKey(typArg, certB64Arg)
        if err != nil {
                return "", err
        }
        pkt, err := cert.VerifySshPktCert(ctx, v.PktVerifier)
        if err != nil {
                return "", err
        }

        var userInfo string
        if accessToken := cert.GetAccessToken(); accessToken != "" {
                // userInfo is optional, so a failed lookup must not fail verification.
                if info, lookupErr := v.UserInfoLookup(ctx, pkt, accessToken); lookupErr == nil {
                        userInfo = info
                }
        }

        if err = v.CheckPolicy(userArg, pkt, userInfo, certB64Arg, typArg, v.denyList); err != nil {
                return "", err
        }

        // sshd expects a public key, not the certificate itself: here that is
        // the key that signed the cert. In this setting there is no separate CA.
        caKey := ssh.MarshalAuthorizedKey(cert.SshCert.SignatureKey)
        return "cert-authority " + string(caKey), nil
}
// ReadFromServerConfig loads the server config at v.ConfigPathArg from v.Fs.
// It requires the file to have mode 0640 with owner root and group
// opksshuser, then parses it.
//
// On success it records the config's deny emails and deny users in
// v.denyList, applies the environment variables the config specifies, and
// returns the result of that step. Read and parse errors are wrapped;
// permission-check errors are returned as-is.
func (v *VerifyCmd) ReadFromServerConfig() error {
        fsys := &afero.Afero{Fs: v.Fs}
        configBytes, err := fsys.ReadFile(v.ConfigPathArg)
        if err != nil {
                return fmt.Errorf("failed to read config file: %w", err)
        }
        if err := v.filePermChecker.CheckPerm(v.ConfigPathArg, []fs.FileMode{0o640}, "root", "opksshuser"); err != nil {
                return err
        }
        serverConfig, err := config.NewServerConfig(configBytes)
        if err != nil {
                return fmt.Errorf("failed to parse config file: %w", err)
        }
        v.denyList = policy.DenyList{Emails: serverConfig.DenyEmails, Users: serverConfig.DenyUsers}
        return serverConfig.SetEnvVars()
}
// UserInfoLookup builds a verifier.UserInfoRequester for pkt and accessToken
// and sets its HTTP client to v.HttpClient. It then returns the result of the
// requester's Request call. AuthorizedKeysCommand uses this result as the
// optional userInfo passed to the policy check.
func (v *VerifyCmd) UserInfoLookup(ctx context.Context, pkt *pktoken.PKToken, accessToken string) (string, error) {
        requester, err := verifier.NewUserInfoRequester(pkt, accessToken)
        if err != nil {
                return "", err
        }
        requester.HttpClient = v.HttpClient
        return requester.Request(ctx)
}
// OpkPolicyEnforcerFunc returns a PolicyEnforcerFunc that can be used in the
// opkssh verify command. It is the CheckPolicy method of a policy.Enforcer
// whose policies for username come from policy.NewMultiPolicyLoader, reading
// files via policy.ReadWithSudoScript.
func OpkPolicyEnforcerFunc(username string) PolicyEnforcerFunc {
        loader := policy.NewMultiPolicyLoader(username, policy.ReadWithSudoScript)
        enforcer := &policy.Enforcer{PolicyLoader: loader}
        return enforcer.CheckPolicy
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
/*
OPKSSH is a command-line tool that allows users to authenticate with OpenID Connect providers and generate SSH keys for secure access to servers.
*/
package main
import (
        "context"
        "errors"
        "fmt"
        "log"
        "os"
        "os/exec"
        "os/signal"
        "regexp"
        "strings"
        "syscall"
        "text/tabwriter"
        "github.com/openpubkey/opkssh/commands"
        config "github.com/openpubkey/opkssh/commands/config"
        "github.com/openpubkey/opkssh/policy"
        "github.com/openpubkey/opkssh/policy/files"
        "github.com/spf13/afero"
        "github.com/spf13/cobra"
        "github.com/spf13/cobra/doc"
        "github.com/thediveo/enumflag/v2"
        "golang.org/x/term"
)
var (
        // These can be overridden at build time using ldflags. For example:
        // go build -v -o /usr/local/bin/opkssh -ldflags "-X main.Version=version"
        //
        // Version is the version string reported by the root command (`opkssh --version`).
        // logFilePathServer is the log file the verify command opens (append/create)
        // and redirects the standard logger to.
        Version           = "unversioned"
        logFilePathServer = "/var/log/opkssh.log" // Remember if you change this, change it in the install script as well
)
// main runs the opkssh CLI and terminates the process with the exit code
// returned by run (0 on success, 1 on failure).
func main() {
        exitCode := run()
        os.Exit(exitCode)
}
// run assembles the opkssh cobra command tree, executes it against the
// process arguments, and returns the process exit code: 0 if the selected
// command succeeded and 1 otherwise.
//
// Subcommands: add, inspect, login, readhome, verify, client provider list,
// and the hidden gendocs documentation helper.
func run() int {
        rootCmd := &cobra.Command{
                SilenceUsage: true,
                Use:          "opkssh",
                Short:        "SSH with OpenPubkey",
                Version:      Version,
                Long: `SSH with OpenPubkey
This program allows users to:
  - Login and create SSH key pairs using their OpenID Connect identity
  - Add policies to auth_id policy files
  - Verify OpenPubkey SSH certificates for use with sshd's AuthorizedKeysCommand`,
                Example: `  opkssh login
  opkssh add root alice@example.com https://accounts.google.com`,
                RunE: func(cmd *cobra.Command, args []string) error {
                        return cmd.Help()
                },
        }
        rootCmd.CompletionOptions.DisableDefaultCmd = true
        addCmd := &cobra.Command{
                SilenceUsage: true,
                Use:          "add <principal> <email|sub|group> <issuer>",
                Short:        "Appends new rule to the policy file",
                Long: `Add appends a new policy entry in the auth_id policy file granting SSH access to the specified email or subscriber ID (sub) or group.
It first attempts to write to the system-wide file (/etc/opk/auth_id). If it lacks permissions to update this file it falls back to writing to the user-specific file (~/.opk/auth_id).
Arguments:
  principal            The target user account (requested principal).
  email|sub|group      Email address, subscriber ID or group authorized to assume this principal. If using an OIDC group, the argument needs to be in the format of oidc:groups:<groupId>.
  issuer               OpenID Connect provider (issuer) URL associated with the email/sub/group.
`,
                Args: cobra.ExactArgs(3),
                Example: `  opkssh add root alice@example.com https://accounts.google.com
  opkssh add alice 103030642802723203118 https://accounts.google.com
  opkssh add developer oidc:groups:developer https://accounts.google.com`,
                RunE: func(cmd *cobra.Command, args []string) error {
                        inputPrincipal := args[0]
                        inputEmail := args[1]
                        inputIssuer := args[2]
                        // Convenience aliases to save user time (who is going to remember the hideous Azure issuer string)
                        switch inputIssuer {
                        case "google":
                                inputIssuer = "https://accounts.google.com"
                        case "azure", "microsoft":
                                inputIssuer = "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0"
                        case "gitlab":
                                inputIssuer = "https://gitlab.com"
                        case "hello":
                                inputIssuer = "https://issuer.hello.coop"
                        }
                        add := commands.AddCmd{
                                HomePolicyLoader:   policy.NewHomePolicyLoader(),
                                SystemPolicyLoader: policy.NewSystemPolicyLoader(),
                                Username:           inputPrincipal,
                        }
                        policyFilePath, err := add.Run(inputPrincipal, inputEmail, inputIssuer)
                        if err != nil {
                                fmt.Fprintf(os.Stderr, "Failed to add to policy: %v\n", err)
                                return err
                        }
                        fmt.Fprintf(os.Stdout, "Successfully added new policy to %s\n", policyFilePath)
                        return nil
                },
        }
        rootCmd.AddCommand(addCmd)
        inspectCmd := &cobra.Command{
                SilenceUsage: true,
                Use:          "inspect <path>",
                Short:        "Inspect and view details of an opkssh generated SSH key",
                Example:      "  opkssh inspect ~/.ssh/id_ecdsa_sk-cert.pub",
                RunE: func(cmd *cobra.Command, args []string) error {
                        keyPathArg := args[0]
                        inspect := commands.NewInspectCmd(keyPathArg, cmd.OutOrStdout())
                        if err := inspect.Run(); err != nil {
                                log.Println("Error executing inspect command:", err)
                                return err
                        }
                        return nil
                },
                Args: cobra.ExactArgs(1),
        }
        rootCmd.AddCommand(inspectCmd)
        var autoRefreshArg bool
        var configPathArg string
        var createConfigArg bool
        var configureArg bool
        var logDirArg string
        var providerArg string
        var sendAccessTokenArg bool
        var disableBrowserOpenArg bool
        var printIdTokenArg bool
        var keyPathArg string
        var keyTypeArg commands.KeyType
        loginCmd := &cobra.Command{
                SilenceUsage: true,
                Use:          "login [alias]",
                Short:        "Authenticate with an OpenID Provider to generate an SSH key for opkssh",
                Long: `Login creates opkssh SSH keys
Login generates a key pair, then opens a browser to authenticate the user with the OpenID Provider. Upon successful authentication, opkssh creates an SSH public key (~/.ssh/id_ecdsa) containing the user's PK token. By default, this SSH key expires after 24 hours, after which the user must run "opkssh login" again to generate a new key.
Users can then SSH into servers configured to use opkssh as the AuthorizedKeysCommand. The server verifies the PK token and grants access if the token is valid and the user is authorized per the auth_id policy.
Arguments:
  alias      The provider alias to use. If not specified, the OPKSSH_DEFAULT provider will be used. The aliases are defined by the OPKSSH_PROVIDERS environment variable. The format is <alias>,<issuer>,<client_id>,<client_secret>,<scopes>
`,
                Example: `  opkssh login
  opkssh login google
  opkssh login --provider=<issuer>,<client_id>,<client_secret>,<scopes>`,
                RunE: func(cmd *cobra.Command, args []string) error {
                        // Cancel the login flow on SIGINT/SIGTERM. NotifyContext replaces a
                        // hand-rolled channel + goroutine whose handler was never
                        // unregistered; stop() releases it when the command returns.
                        ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
                        defer stop()
                        var providerAliasArg string
                        if len(args) > 0 {
                                providerAliasArg = args[0]
                        }
                        login := commands.NewLogin(autoRefreshArg, configPathArg, createConfigArg, configureArg, logDirArg, sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyPathArg, providerAliasArg, keyTypeArg)
                        if err := login.Run(ctx); err != nil {
                                log.Println("Error executing login command:", err)
                                return err
                        }
                        return nil
                },
                Args: cobra.MaximumNArgs(1),
        }
        // Define flags for login.
        loginCmd.Flags().BoolVar(&autoRefreshArg, "auto-refresh", false, "Automatically refresh PK token after login")
        loginCmd.Flags().StringVar(&configPathArg, "config-path", "", "Path to the client config file. Default: ~/.opk/config.yml on linux and %APPDATA%\\.opk\\config.yml on windows")
        loginCmd.Flags().BoolVar(&createConfigArg, "create-config", false, "Creates a client config file if it does not exist")
        loginCmd.Flags().BoolVar(&configureArg, "configure", false, "Apply changes to ssh config and create ~/.ssh/opkssh directory")
        loginCmd.Flags().StringVar(&logDirArg, "log-dir", "", "Directory to write output logs")
        loginCmd.Flags().BoolVar(&disableBrowserOpenArg, "disable-browser-open", false, "Set this flag to disable opening the browser. Useful for choosing the browser you want to use")
        loginCmd.Flags().BoolVar(&printIdTokenArg, "print-id-token", false, "Set this flag to print out the contents of the id_token. Useful for inspecting claims")
        loginCmd.Flags().BoolVar(&sendAccessTokenArg, "send-access-token", false, "Set this flag to send the Access Token as well as the PK Token in the SSH cert. The Access Token is used to call the userinfo endpoint to get claims not included in the ID Token")
        loginCmd.Flags().StringVar(&providerArg, "provider", "", "OpenID Provider specification in the format: <issuer>,<client_id> or <issuer>,<client_id>,<client_secret> or <issuer>,<client_id>,<client_secret>,<scopes>")
        loginCmd.Flags().StringVarP(&keyPathArg, "private-key-file", "i", "", "Path where private keys is written")
        loginCmd.Flags().VarP(enumflag.New(&keyTypeArg, "Key Type", map[commands.KeyType][]string{commands.ECDSA: {commands.ECDSA.String()}, commands.ED25519: {commands.ED25519.String()}}, enumflag.EnumCaseInsensitive), "key-type", "t", "Type of key to generate")
        rootCmd.AddCommand(loginCmd)
        readhomeCmd := &cobra.Command{
                SilenceUsage: true,
                Use:          "readhome <principal>",
                Short:        "Read the principal's home policy file",
                Long: `Read the principal's policy file (/home/<principal>/.opk/auth_id).
You should not call this command directly. It is called by the opkssh verify command as part of the AuthorizedKeysCommand process to read the user's policy  (principals) home file (~/.opk/auth_id) with sudoer permissions. This allows us to use an unprivileged user as the AuthorizedKeysCommand user.
`,
                Args:    cobra.ExactArgs(1),
                Example: `  opkssh readhome alice`,
                RunE: func(cmd *cobra.Command, args []string) error {
                        // Use cobra's parsed positional argument. Indexing os.Args
                        // directly picks the wrong value if anything (e.g. "--")
                        // precedes the principal on the command line.
                        userArg := args[0]
                        fileBytes, err := commands.ReadHome(userArg)
                        if err != nil {
                                fmt.Fprintf(os.Stderr, "Failed to read user's home policy file: %v\n", err)
                                return err
                        }
                        fmt.Fprint(os.Stdout, string(fileBytes))
                        return nil
                },
        }
        rootCmd.AddCommand(readhomeCmd)
        var serverConfigPathArg string
        verifyCmd := &cobra.Command{
                SilenceUsage: true,
                Use:          "verify <principal> <cert> <key_type>",
                Short:        "Verify an SSH key (used by sshd AuthorizedKeysCommand)",
                Long: `Verify extracts a PK token from a base64-encoded SSH certificate and verifies it against policy. It expects an allowed provider file at /etc/opk/providers and a user policy file at either /etc/opk/auth_id or ~/.opk/auth_id.
This command is intended to be called by sshd as an AuthorizedKeysCommand:
  https://man.openbsd.org/sshd_config#AuthorizedKeysCommand
During installation, opkssh typically adds these lines to /etc/ssh/sshd_config:
  AuthorizedKeysCommand /usr/local/bin/opkssh verify %%u %%k %%t
  AuthorizedKeysCommandUser opksshuser
Where the tokens in /etc/ssh/sshd_config are defined as:
  %%u   Target username (requested principal)
  %%k   Base64-encoded SSH public key (SSH certificate) provided for authentication
  %%t   Public key type (SSH certificate format, e.g., ecdsa-sha2-nistp256-cert-v01@openssh.com)
Verification checks performed:
  1. Ensures the PK token is properly formed, signed, and issued by the specified OpenID Provider (OP).
  2. Confirms the PK token's issue (iss) and client ID (audience) are listed in the allowed provider file (/etc/opk/providers) and the token is not expired.
  3. Validates the identity (email or sub) in the PK token against user policies (/etc/opk/auth_id or ~/.opk/auth_id) to ensure it can assume the requested username (principal).
If all checks pass, Verify authorizes the SSH connection.
Arguments:
  principal    Target username.
  cert         Base64-encoded SSH certificate.
  key_type     SSH certificate key type (e.g., ecdsa-sha2-nistp256-cert-v01@openssh.com)`,
                Args:    cobra.ExactArgs(3),
                Example: `  opkssh verify root <base64-encoded-cert> ecdsa-sha2-nistp256-cert-v01@openssh.com`,
                RunE: func(cmd *cobra.Command, args []string) error {
                        ctx := context.Background()
                        // Setup logger
                        logFile, err := os.OpenFile(logFilePathServer, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0660) // Owner and group can read/write
                        if err != nil {
                                fmt.Fprintf(os.Stderr, "Error opening log file: %v\n", err)
                                // It could be very difficult to figure out what is going on if the log file was deleted. Hopefully this message saves someone an hour of debugging.
                                fmt.Fprintf(os.Stderr, "Check if log exists at %v, if it does not create it with permissions: chown root:opksshuser %v; chmod 660 %v\n", logFilePathServer, logFilePathServer, logFilePathServer)
                        } else {
                                defer logFile.Close()
                                log.SetOutput(logFile)
                        }
                        // Logs if using an unsupported OpenSSH version
                        checkOpenSSHVersion()
                        // The "AuthorizedKeysCommand" func is designed to be used by sshd and specified as an AuthorizedKeysCommand
                        // ref: https://man.openbsd.org/sshd_config#AuthorizedKeysCommand
                        log.Println(strings.Join(os.Args, " "))
                        userArg := args[0]
                        certB64Arg := args[1]
                        typArg := args[2]
                        providerPolicyPath := "/etc/opk/providers"
                        providerPolicy, err := policy.NewProviderFileLoader().LoadProviderPolicy(providerPolicyPath)
                        if err != nil {
                                log.Println("Failed to open /etc/opk/providers:", err)
                                return err
                        }
                        printConfigProblems()
                        log.Println("Providers loaded: ", providerPolicy.ToString())
                        pktVerifier, err := providerPolicy.CreateVerifier()
                        if err != nil {
                                log.Println("Failed to create pk token verifier (likely bad configuration):", err)
                                return err
                        }
                        v := commands.NewVerifyCmd(*pktVerifier, commands.OpkPolicyEnforcerFunc(userArg), serverConfigPathArg)
                        if err := v.ReadFromServerConfig(); err != nil {
                                log.Println("Failed to set environment variables in config:", err)
                        }
                        authKey, err := v.AuthorizedKeysCommand(ctx, userArg, typArg, certB64Arg)
                        if err != nil {
                                log.Println("failed to verify:", err)
                                return err
                        }
                        log.Println("successfully verified")
                        // sshd is awaiting a specific line, which we print here. Printing anything else before or after will break our solution
                        fmt.Println(authKey)
                        return nil
                },
        }
        verifyCmd.Flags().StringVar(&serverConfigPathArg, "config-path", "/etc/opk/config.yml", "Path to the server config file. Default: /etc/opk/config.yml.")
        rootCmd.AddCommand(verifyCmd)
        clientCmd := &cobra.Command{
                Use:     "client [subcommand]",
                Short:   "Interact with client configuration",
                Example: `  opkssh client provider list`,
                Args:    cobra.ExactArgs(0),
        }
        providerCmd := &cobra.Command{
                Use:     "provider [subcommand]",
                Short:   "Interact with provider configuration",
                Example: `  opkssh client provider list`,
                Args:    cobra.ExactArgs(0),
        }
        providerListCmd := &cobra.Command{
                Use:     "list",
                Short:   "List configured providers",
                Example: `  opkssh client provider list`,
                Args:    cobra.ExactArgs(0),
                RunE: func(cmd *cobra.Command, args []string) error {
                        // Return errors instead of calling log.Fatal so cobra's error
                        // handling (and any deferred cleanup) runs; the exit code is
                        // still 1.
                        clientConfig, err := config.GetClientConfigFromFile(configPathArg, afero.NewOsFs())
                        if err != nil {
                                return fmt.Errorf("unable to load providers: %w", err)
                        }
                        isTTY := term.IsTerminal(int(os.Stdout.Fd()))
                        var w *tabwriter.Writer
                        if isTTY {
                                // Nice aligned table for TTY output
                                w = tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
                                fmt.Fprintln(w, "Alias\tIssuer")
                                fmt.Fprintln(w, "-----\t------")
                        } else {
                                // Simpler formatting for non-TTY (e.g., when piping to a file)
                                w = tabwriter.NewWriter(os.Stdout, 0, 0, 1, ' ', tabwriter.DiscardEmptyColumns)
                        }
                        for _, p := range clientConfig.Providers {
                                for _, alias := range p.AliasList {
                                        fmt.Fprintf(w, "%s\t%s\n", alias, p.Issuer)
                                }
                        }
                        w.Flush()
                        // After printing, also check that the providers can be loaded into a map.
                        if _, err = config.CreateProvidersMap(clientConfig.Providers); err != nil {
                                return fmt.Errorf("unable to parse providers: %w", err)
                        }
                        return nil
                },
        }
        providerListCmd.Flags().StringVar(&configPathArg, "config-path", "", "Path to the client config file. Default: ~/.opk/config.yml on linux and %APPDATA%\\.opk\\config.yml on windows.")
        providerCmd.AddCommand(providerListCmd)
        clientCmd.AddCommand(providerCmd)
        rootCmd.AddCommand(clientCmd)
        // genDocsCmd is a hidden command used as a helper for generating our
        // command line reference documentation.
        genDocsCmd := &cobra.Command{
                Use:    "gendocs <output_dir>",
                Hidden: true,
                Args:   cobra.MaximumNArgs(1),
                RunE: func(cmd *cobra.Command, args []string) error {
                        path := "./docs/cli/"
                        // MaximumNArgs(1) means the optional output_dir is args[0];
                        // the old check on args[1] could never be true.
                        if len(args) > 0 {
                                path = args[0]
                        }
                        if err := os.MkdirAll(path, 0775); err != nil {
                                return err
                        }
                        return doc.GenMarkdownTree(rootCmd, path)
                },
        }
        rootCmd.AddCommand(genDocsCmd)
        if err := rootCmd.Execute(); err != nil {
                return 1
        }
        return 0
}
// printConfigProblems logs every configuration problem recorded so far in the
// shared files.ConfigProblems() log, preceded by a single warning header line.
// Nothing is logged when no problems have been recorded.
func printConfigProblems() {
	recorded := files.ConfigProblems().GetProblems()
	if len(recorded) == 0 {
		return
	}
	log.Println("Warning: Encountered the following configuration problems:")
	for i := range recorded {
		log.Println(recorded[i].String())
	}
}
// checkOpenSSHVersion logs a warning when the OpenSSH installed on this host
// appears to be older than 8.1, or when its version cannot be determined.
//
// OpenSSH used to impose a 4096-octet limit on the string buffers available to
// the percent_expand function; exceeding it aborted with
// "fatal: percent_expand: string too long". The limit was removed in the 8.1
// release (October 2019). This check is advisory only: it never fails, it only
// logs.
func checkOpenSSHVersion() {
	detected := getOpenSSHVersion()
	if detected == "" {
		log.Println("Warning: Could not determine OpenSSH version")
		return
	}
	// A parse error is deliberately folded into "not supported": in both cases
	// the operator should hear that the version requirement is not confirmed.
	supported, _ := isOpenSSHVersion8Dot1OrGreater(detected)
	if !supported {
		log.Println("Warning: OpenPubkey SSH requires OpenSSH v. 8.1 or greater")
	}
}
// getOpenSSHVersion returns a version string for the locally installed OpenSSH,
// or "" when every detection method fails.
//
// Detection order:
//  1. a package-manager query chosen from detectOS() (rpm, dpkg-query or
//     pacman), whose output the shell pipeline turns into "OpenSSH_<major>.<minor>";
//  2. the trimmed output of `ssh -V`;
//  3. the trimmed output of `sshd -V`.
//
// The first method that exits successfully with non-blank output wins.
func getOpenSSHVersion() string {
	// run executes a command and returns its combined stdout/stderr with
	// surrounding whitespace removed, along with the execution error.
	run := func(name string, args ...string) (string, error) {
		out, err := exec.Command(name, args...).CombinedOutput()
		return strings.TrimSpace(string(out)), err
	}

	osType := detectOS()
	log.Printf("Attempting OS-specific version detection for: %s", osType)

	var pkgQuery string
	switch osType {
	case OSTypeRHEL: // RedHat-based systems (CentOS, RHEL, Fedora)
		pkgQuery = "version=$(/usr/bin/rpm -q --qf \"%{VERSION}\\n\" openssh-server 2>/dev/null | /bin/sed -E 's/^([0-9]+\\.[0-9]+).*/\\1/' | head -1); if [ -n \"$version\" ]; then /bin/echo \"OpenSSH_$version\"; fi"
	case OSTypeDebian: // Debian-based systems (Debian, Ubuntu)
		pkgQuery = "version=$(/usr/bin/dpkg-query -W -f='${Version}\\n' openssh-server 2>/dev/null | /bin/sed -E 's/^[0-9]*:?([0-9]+\\.[0-9]+).*/\\1/' | head -1); if [ -n \"$version\" ]; then /bin/echo \"OpenSSH_$version\"; fi"
	case OSTypeArch: // Arch Linux
		pkgQuery = "version=$(/usr/bin/pacman -Qi openssh 2>/dev/null | /usr/bin/awk '/^Version/ {print $3}' | /bin/sed -E 's/^([0-9]+\\.[0-9]+).*/\\1/' | head -1); if [ -n \"$version\" ]; then /bin/echo \"OpenSSH_$version\"; fi"
	case OSTypeSUSE: // SUSE-based systems
		pkgQuery = "version=$(/usr/bin/rpm -q --qf \"%{VERSION}\\n\" openssh 2>/dev/null | /bin/sed -E 's/^([0-9]+\\.[0-9]+).*/\\1/' | head -1); if [ -n \"$version\" ]; then /bin/echo \"OpenSSH_$version\"; fi"
	default:
		log.Printf("Warning: Could not determine OpenSSH version using OS-specific methods for %s", osType)
	}
	if pkgQuery != "" {
		if version, err := run("/bin/sh", "-c", pkgQuery); err == nil && version != "" {
			return version
		}
	}

	// Generic fallbacks that work on most systems.
	version, err := run("ssh", "-V")
	if err == nil && version != "" {
		return version
	}
	log.Println("Warning: Error executing ssh -V:", err)

	version, err = run("sshd", "-V")
	if err == nil && version != "" {
		return version
	}
	log.Println("Warning: Error executing sshd -V:", err)
	return ""
}
// opensshVersionPattern captures the leading major and (optional) minor
// version numbers of an OpenSSH version such as "9.9p1" or "10.0".
var opensshVersionPattern = regexp.MustCompile(`^(\d+)(?:\.(\d+))?`)

// isOpenSSHVersion8Dot1OrGreater reports whether opensshVersion denotes
// OpenSSH 8.1 or newer.
//
// opensshVersion may be the raw output of `ssh -V` (for example
// "OpenSSH_9.6p1 Ubuntu-3ubuntu13.5, OpenSSL 3.0.13 30 Jan 2024") or a bare
// version such as "8.1". Only the text before the first ", " is examined, an
// optional "OpenSSH_" prefix is removed, and the leading major[.minor] numbers
// are compared numerically. A missing minor number counts as 0.
//
// It returns false and an error when no leading version number can be parsed.
func isOpenSSHVersion8Dot1OrGreater(opensshVersion string) (bool, error) {
	opensshVersion = strings.TrimPrefix(
		strings.Split(opensshVersion, ", ")[0],
		"OpenSSH_",
	)
	matches := opensshVersionPattern.FindStringSubmatch(opensshVersion)
	if matches == nil {
		log.Println("Invalid OpenSSH version")
		return false, errors.New("invalid OpenSSH version")
	}

	// The regex guarantees the groups are plain decimal digits; %d parses them
	// as base 10 and reports overflow as an error.
	atoi := func(digits string) (int, error) {
		var n int
		_, err := fmt.Sscanf(digits, "%d", &n)
		return n, err
	}

	// Compare numerically: a lexical string comparison wrongly treats
	// "10.0" as older than "8.1".
	major, err := atoi(matches[1])
	if err != nil {
		log.Println("Invalid OpenSSH version")
		return false, fmt.Errorf("invalid OpenSSH version: %w", err)
	}
	minor := 0
	if matches[2] != "" {
		if minor, err = atoi(matches[2]); err != nil {
			log.Println("Invalid OpenSSH version")
			return false, fmt.Errorf("invalid OpenSSH version: %w", err)
		}
	}
	return major > 8 || (major == 8 && minor >= 1), nil
}
// OSType identifies the family of operating system opkssh runs on. It selects
// which package manager getOpenSSHVersion queries.
type OSType string

// Operating system families recognised by detectOS.
const (
	OSTypeGeneric OSType = "generic"
	OSTypeRHEL    OSType = "rhel"
	OSTypeDebian  OSType = "debian"
	OSTypeArch    OSType = "arch"
	OSTypeSUSE    OSType = "suse"
)

// detectOS determines the type of operating system.
//
// Distribution marker files are checked first, in order: /etc/redhat-release,
// /etc/debian_version, /etc/arch-release, /etc/SuSE-release, /etc/SUSE-brand.
// If none exists, /etc/os-release is parsed by osTypeFromOSRelease. When
// nothing matches, OSTypeGeneric is returned.
func detectOS() OSType {
	markers := []struct {
		path   string
		osType OSType
	}{
		{"/etc/redhat-release", OSTypeRHEL},
		{"/etc/debian_version", OSTypeDebian},
		{"/etc/arch-release", OSTypeArch},
		{"/etc/SuSE-release", OSTypeSUSE},
		{"/etc/SUSE-brand", OSTypeSUSE},
	}
	for _, m := range markers {
		if _, err := os.Stat(m.path); err == nil {
			return m.osType
		}
	}
	// /etc/os-release exists on most modern Linux systems.
	if content, err := os.ReadFile("/etc/os-release"); err == nil {
		return osTypeFromOSRelease(string(content))
	}
	return OSTypeGeneric
}

// osTypeFromOSRelease maps the contents of an os-release file to an OSType.
//
// Lines are parsed as KEY=VALUE pairs. Surrounding single or double quotes are
// stripped from values, because distributions such as RHEL, CentOS and SUSE
// ship quoted values (ID="rhel"). The ID value is consulted first and then
// each entry of ID_LIKE, so derivatives such as ID="rocky" with
// ID_LIKE="rhel centos fedora" resolve to their parent family. Matching is by
// prefix, so "opensuse-leap" maps to SUSE and "archarm" to Arch. Returns
// OSTypeGeneric when nothing matches.
func osTypeFromOSRelease(content string) OSType {
	fields := make(map[string]string)
	for _, line := range strings.Split(content, "\n") {
		key, value, ok := strings.Cut(strings.TrimSpace(line), "=")
		if !ok || strings.HasPrefix(key, "#") {
			continue
		}
		fields[strings.TrimSpace(key)] = strings.Trim(strings.TrimSpace(value), `"'`)
	}

	families := []struct {
		prefix string
		osType OSType
	}{
		{"rhel", OSTypeRHEL},
		{"centos", OSTypeRHEL},
		{"fedora", OSTypeRHEL},
		{"debian", OSTypeDebian},
		{"ubuntu", OSTypeDebian},
		{"arch", OSTypeArch},
		{"sles", OSTypeSUSE},
		{"opensuse", OSTypeSUSE},
		{"suse", OSTypeSUSE},
	}

	ids := append([]string{fields["ID"]}, strings.Fields(fields["ID_LIKE"])...)
	for _, id := range ids {
		id = strings.ToLower(id)
		for _, f := range families {
			if strings.HasPrefix(id, f.prefix) {
				return f.osType
			}
		}
	}
	return OSTypeGeneric
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package policy
import (
        "encoding/json"
        "errors"
        "fmt"
        "log"
        "os"
        "strings"
        "github.com/openpubkey/openpubkey/pktoken"
        "github.com/openpubkey/opkssh/policy/plugins"
        "golang.org/x/exp/slices"
)
// Prefixes that mark a policy identity attribute as something other than a
// plain email address or sub claim; validateClaim interprets them.
const (
        // OIDC_CLAIMS prefixes an attribute of the form
        // "oidc:<claim name>:<required value>", which matches when the named
        // claim (from the ID token or the userinfo response) contains the value.
        OIDC_CLAIMS         = "oidc:"
        // OIDC_WILDCARD_EMAIL prefixes an attribute of the form
        // "oidc-match-end:email:<suffix>", which matches any email claim ending
        // with <suffix>, compared case-insensitively.
        OIDC_WILDCARD_EMAIL = "oidc-match-end:email:"
)
// DenyList represents the DenyLists in the server config. Enforcer.CheckPolicy
// enforces it before consulting policy plugins or the policy file, so a match
// here always refuses access.
type DenyList struct {
        // Emails are compared case-insensitively against the ID token's email claim.
        Emails []string
        // Users are compared case-insensitively against the requested principal.
        Users  []string
}
// Enforcer evaluates opkssh policy to determine if the desired principal is
// permitted
type Enforcer struct {
        // PolicyLoader supplies the file-based policy that CheckPolicy consults
        // after the deny list and any policy plugins.
        PolicyLoader Loader
}
// checkedClaims holds the identity claims that policy evaluation inspects. It
// is decoded from the PK token payload and, when available, from the userinfo
// response. ExtraClaims has no JSON tag mapping: UnmarshalJSON fills it with
// every top-level claim (email and sub included) flattened to strings.
type checkedClaims struct {
	Email       string              `json:"email"`
	Sub         string              `json:"sub"`
	ExtraClaims map[string][]string `json:"-"`
}

// UnmarshalJSON decodes the email and sub claims into their typed fields and
// records every top-level claim in ExtraClaims as a []string:
//   - a string becomes a one-element slice;
//   - an array becomes one string per element (non-strings via fmt.Sprint);
//   - any other value (number, boolean, null, object) becomes a one-element
//     slice holding its fmt.Sprint form.
func (s *checkedClaims) UnmarshalJSON(data []byte) error {
	// Decode through an alias type that lacks this method, so json.Unmarshal
	// does not recurse back into UnmarshalJSON.
	type checkedClaimsAlias checkedClaims
	var a checkedClaimsAlias
	if err := json.Unmarshal(data, &a); err != nil {
		return err
	}
	*s = checkedClaims(a)

	// Decode a second time, generically, to capture all claims.
	var raw map[string]any
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	s.ExtraClaims = make(map[string][]string, len(raw))
	for k, v := range raw {
		switch t := v.(type) {
		case string:
			s.ExtraClaims[k] = []string{t}
		case []any:
			// Turn all elements in a list into a string
			out := make([]string, 0, len(t))
			for _, e := range t {
				if str, ok := e.(string); ok {
					out = append(out, str)
				} else {
					out = append(out, fmt.Sprint(e))
				}
			}
			s.ExtraClaims[k] = out
		default:
			// Turn numbers/bools etc into strings
			s.ExtraClaims[k] = []string{fmt.Sprint(t)}
		}
	}
	return nil
}
// The default location for policy plugins. CheckPolicy passes this directory
// to the policy plugin enforcer on every check.
const pluginPolicyDir = "/etc/opk/policy.d"
// EscapedSplit splits s on sep, except where sep appears between a pair of
// double quotes. Quote characters are kept in the returned fields, and empty
// fields (from leading, trailing or repeated separators) are dropped, so
// EscapedSplit(`oidc:"https://example.com/groups":admin`, ':') returns
// ["oidc", `"https://example.com/groups"`, "admin"]. An empty input yields an
// empty, non-nil slice.
//
// The scan runs explicitly left to right. strings.FieldsFunc documents that it
// makes no guarantee about the order in which it calls its predicate, so the
// quote-tracking state must not live in a FieldsFunc callback.
func EscapedSplit(s string, sep rune) []string {
	fields := []string{}
	quoted := false
	start := -1 // byte offset where the current field began; -1 between fields
	for i, r := range s {
		if r == '"' {
			quoted = !quoted
		}
		if !quoted && r == sep {
			if start >= 0 {
				fields = append(fields, s[start:i])
				start = -1
			}
			continue
		}
		if start < 0 {
			start = i
		}
	}
	if start >= 0 {
		fields = append(fields, s[start:])
	}
	return fields
}
// validateClaim reports whether user's identity attribute matches claims.
//
// The identity attribute is interpreted as one of:
//   - "oidc:<claim>:<value>": the named claim (surrounding quotes stripped
//     from <claim>) must contain <value>. When there are more than three
//     ':'-separated sections, the last one is used as <value>.
//   - "oidc-match-end:email:<suffix>": the email claim must end with the
//     non-empty <suffix>, compared case-insensitively.
//   - anything else: a case-insensitive match on the email claim or a
//     case-sensitive match on the sub claim.
//
// Malformed oidc attributes (missing claim name or value) and wildcard
// attributes with an empty suffix never match; they fail closed instead of
// panicking or matching every email.
func validateClaim(claims *checkedClaims, user *User) bool {
	// Reject tokens whose email itself starts with the wildcard prefix.
	// NOTE(review): presumably this prevents such an email from matching a
	// wildcard policy entry via the exact-email comparison below.
	if strings.HasPrefix(claims.Email, OIDC_WILDCARD_EMAIL) {
		return false
	}

	// Should we match on an oidc claim?
	if strings.HasPrefix(user.IdentityAttribute, OIDC_CLAIMS) {
		sections := EscapedSplit(user.IdentityAttribute, ':')
		// Need at least "oidc", a claim name and a value; indexing a shorter
		// slice would panic (e.g. for the attribute "oidc:").
		if len(sections) < 3 {
			return false
		}
		claimName := strings.Trim(sections[1], "\"")
		return slices.Contains(claims.ExtraClaims[claimName], sections[len(sections)-1])
	}

	// Should we match on the email wildcard claim?
	wildCardEmailMatch := false
	if strings.HasPrefix(user.IdentityAttribute, OIDC_WILDCARD_EMAIL) {
		suffix := strings.TrimPrefix(user.IdentityAttribute, OIDC_WILDCARD_EMAIL)
		// An empty suffix would match every email address.
		if suffix != "" && strings.HasSuffix(strings.ToLower(claims.Email), strings.ToLower(suffix)) {
			wildCardEmailMatch = true
		}
	}

	// email should be a case-insensitive check
	// sub should be a case-sensitive check
	return wildCardEmailMatch || strings.EqualFold(claims.Email, user.IdentityAttribute) || claims.Sub == user.IdentityAttribute
}
// CheckPolicy loads opkssh policy and checks to see if there is a policy
// permitting access to principalDesired for the user identified by the PKT's
// email claim. Returns nil if access is granted. Otherwise, an error is
// returned.
//
// Evaluation order:
//  1. denyList: a matching email claim or principal is refused outright.
//  2. Policy plugins in pluginPolicyDir: access is granted when the plugin
//     results report Allowed(). Plugin errors are logged, never fatal.
//  3. The policy from PolicyLoader: an entry grants access when its issuer
//     equals the PK token's issuer, it lists principalDesired, and its
//     identity attribute matches the ID token claims or, when userInfoJson is
//     non-empty, the userinfo claims. Userinfo whose sub differs from the ID
//     token's sub is rejected before any entry is considered.
//
// It is security critical to verify the pkt first before calling this function.
// This is because if this function is called first, a timing channel exists which
// allows an attacker to check what identities and principals are allowed by the policy.
func (p *Enforcer) CheckPolicy(principalDesired string, pkt *pktoken.PKToken, userInfoJson string, sshCert string, keyType string, denyList DenyList) error {
	var claims checkedClaims
	if err := json.Unmarshal(pkt.Payload, &claims); err != nil {
		return fmt.Errorf("error unmarshalling pk token payload: %w", err)
	}
	issuer, err := pkt.Issuer()
	if err != nil {
		return fmt.Errorf("error getting issuer from pk token: %w", err)
	}

	// Enforce deny list first
	for _, email := range denyList.Emails {
		if strings.EqualFold(claims.Email, email) {
			return fmt.Errorf("denied email %s", email)
		}
	}
	for _, user := range denyList.Users {
		if strings.EqualFold(principalDesired, user) {
			return fmt.Errorf("denied user %s", user)
		}
	}

	pluginPolicy := plugins.NewPolicyPluginEnforcer()
	results, err := pluginPolicy.CheckPolicies(pluginPolicyDir, pkt, userInfoJson, principalDesired, sshCert, keyType)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			log.Println("Skipping policy plugins: no plugins found at " + pluginPolicyDir)
		} else {
			log.Printf("Error checking policy plugins: %v \n", err)
		}
		// Despite the error, we don't fail here because we still want to check
		// the standard policy below. Policy plugins can only expand the set of
		// allow set, not shrink it.
	} else {
		for _, result := range results {
			commandRunStr := strings.Join(result.CommandRun, " ")
			log.Printf("Policy plugin result, path: (%s), allowed: (%t), error: (%v), command_run: (%s), policyOutput: (%s)\n", result.Path, result.Allowed, result.Error, commandRunStr, result.PolicyOutput)
		}
		if results.Allowed() {
			log.Printf("Access granted by policy plugin\n")
			return nil
		}
	}

	policy, source, err := p.PolicyLoader.Load()
	if err != nil {
		return fmt.Errorf("error loading policy: %w", err)
	}

	var userInfoClaims *checkedClaims
	if userInfoJson != "" {
		userInfoClaims = new(checkedClaims)
		if err := json.Unmarshal([]byte(userInfoJson), userInfoClaims); err != nil {
			return fmt.Errorf("error unmarshalling claims from userinfo endpoint: %w", err)
		}
		// The underlying library checks idT.sub == userInfo.sub when we call the
		// userinfo endpoint. We want to be extra sure so we also check it here.
		// The check does not depend on any policy entry, so it runs once, and it
		// also runs when the policy has no entries.
		if claims.Sub != userInfoClaims.Sub {
			return fmt.Errorf("userInfo sub claim (%s) does not match user policy sub claim (%s)", userInfoClaims.Sub, claims.Sub)
		}
	}

	for _, user := range policy.Users {
		if issuer != user.Issuer {
			continue
		}
		// if they are, then check if the desired principal is allowed
		if !slices.Contains(user.Principals, principalDesired) {
			continue
		}
		// check each entry to see if the user in the checkedClaims is included
		if validateClaim(&claims, &user) {
			// access granted
			return nil
		}
		// check each entry to see if the user matches the userInfoClaims
		if userInfoClaims != nil && validateClaim(userInfoClaims, &user) {
			// access granted
			return nil
		}
	}
	return fmt.Errorf("no policy to allow %s with (issuer=%s) to assume %s, check policy config at %s", claims.Email, issuer, principalDesired, source.Source())
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package files
import (
        "fmt"
        "strings"
        "sync"
)
// ConfigProblem describes one problem found while reading a configuration
// file: the file, the offending line (text and number), the error message and
// the Source that reported it.
type ConfigProblem struct {
	Filepath            string
	OffendingLine       string
	OffendingLineNumber int
	ErrorMessage        string
	Source              string
}

// String renders the problem on a single line. Source is not included.
func (e ConfigProblem) String() string {
	return "encountered error: " + e.ErrorMessage +
		", reading " + e.OffendingLine +
		" in " + e.Filepath +
		" at line " + fmt.Sprint(e.OffendingLineNumber)
}

// ConfigLog accumulates ConfigProblems behind a mutex, so problems can be
// recorded, read and cleared from several goroutines. It must not be copied
// after first use.
type ConfigLog struct {
	log      []ConfigProblem
	logMutex sync.Mutex
}

// RecordProblem appends entry to the log.
func (c *ConfigLog) RecordProblem(entry ConfigProblem) {
	c.logMutex.Lock()
	c.log = append(c.log, entry)
	c.logMutex.Unlock()
}

// GetProblems returns a snapshot copy of the recorded problems. Modifying the
// returned slice does not affect the log.
func (c *ConfigLog) GetProblems() []ConfigProblem {
	c.logMutex.Lock()
	defer c.logMutex.Unlock()
	snapshot := make([]ConfigProblem, len(c.log))
	copy(snapshot, c.log)
	return snapshot
}

// NoProblems reports whether the log is empty.
func (c *ConfigLog) NoProblems() bool {
	c.logMutex.Lock()
	count := len(c.log)
	c.logMutex.Unlock()
	return count == 0
}

// String joins the String form of every recorded problem with newlines. It
// takes no lock itself; GetProblems provides the synchronised snapshot.
func (c *ConfigLog) String() string {
	var sb strings.Builder
	for i, problem := range c.GetProblems() {
		if i > 0 {
			sb.WriteString("\n")
		}
		sb.WriteString(problem.String())
	}
	return sb.String()
}

// Clear removes every recorded problem.
func (c *ConfigLog) Clear() {
	c.logMutex.Lock()
	defer c.logMutex.Unlock()
	c.log = make([]ConfigProblem, 0)
}

var (
	singleton *ConfigLog
	once      sync.Once
)

// ConfigProblems returns the process-wide ConfigLog, creating it on first use.
// sync.Once makes the lazy initialisation safe across goroutines.
func ConfigProblems() *ConfigLog {
	once.Do(func() {
		singleton = &ConfigLog{log: make([]ConfigProblem, 0)}
	})
	return singleton
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package files
import (
        "fmt"
        "io/fs"
        "path/filepath"
        "github.com/spf13/afero"
)
// FileLoader contains methods to read/write files (such as the opkssh policy
// file) from/to an arbitrary afero filesystem. LoadFileAtPath fails and returns
// an error immediately if the file's permission bits are not exactly
// RequiredPerm. CreateIfDoesNotExist sets RequiredPerm on the file it creates.
// Dump passes RequiredPerm to afero.WriteFile, which applies it when the file
// is created.
type FileLoader struct {
	Fs           afero.Fs
	RequiredPerm fs.FileMode
}

// CreateIfDoesNotExist creates an empty file at path, and any missing parent
// directories (mode 0750), if nothing exists there yet. The new file's mode is
// then set to RequiredPerm. An existing file is left untouched.
func (l FileLoader) CreateIfDoesNotExist(path string) error {
	exists, err := afero.Exists(l.Fs, path)
	if err != nil {
		return err
	}
	if exists {
		return nil
	}
	if err := l.Fs.MkdirAll(filepath.Dir(path), 0750); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}
	file, err := l.Fs.Create(path)
	if err != nil {
		return fmt.Errorf("failed to create file: %w", err)
	}
	// Surface close failures instead of silently ignoring them.
	if err := file.Close(); err != nil {
		return fmt.Errorf("failed to close file: %w", err)
	}
	if err := l.Fs.Chmod(path, l.RequiredPerm); err != nil {
		return fmt.Errorf("failed to set file permissions: %w", err)
	}
	return nil
}

// LoadFileAtPath validates that the file at path exists, can be read
// by the current process, and has exactly RequiredPerm as its permission bits.
// It returns the file contents if those checks pass and reading succeeds;
// otherwise it returns an error. The stat error is wrapped with %w, so callers
// can test it with errors.Is (for example against os.ErrNotExist).
func (l *FileLoader) LoadFileAtPath(path string) ([]byte, error) {
	// Check if file exists and we can access it
	if _, err := l.Fs.Stat(path); err != nil {
		return nil, fmt.Errorf("failed to describe the file at path: %w", err)
	}
	// Validate that file has correct permission bits set
	if err := NewPermsChecker(l.Fs).CheckPerm(path, []fs.FileMode{l.RequiredPerm}, "", ""); err != nil {
		return nil, fmt.Errorf("policy file has insecure permissions: %w", err)
	}
	afs := &afero.Afero{Fs: l.Fs}
	content, err := afs.ReadFile(path)
	if err != nil {
		return nil, err
	}
	return content, nil
}

// Dump writes fileBytes to path, truncating any existing content. RequiredPerm
// is passed as the mode for a newly created file.
func (l *FileLoader) Dump(fileBytes []byte, path string) error {
	return afero.WriteFile(l.Fs, path, fileBytes, l.RequiredPerm)
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package files
import (
        "fmt"
        "io/fs"
        "os/exec"
        "strings"
        "github.com/spf13/afero"
)
// ModeSystemPerms is the expected permission bits that should be set for opkssh
// system policy files (`/etc/opk/auth_id`, `/etc/opk/providers`). Mode 0640
// (rw-r-----) means that only the owner of the file can read and write it, the
// file's group (which should be opksshuser) can only read it, and all other
// users have no access.
const ModeSystemPerms = fs.FileMode(0640)
// ModeHomePerms is the expected permission bits that should be set for opkssh
// user home policy files `~/.opk/auth_id`. Mode 0600 (rw-------) means only
// the owner can read or write the file.
const ModeHomePerms = fs.FileMode(0600)
// PermsChecker contains methods to check the ownership, group
// and file permissions of a file on a Unix-like system. Fs is used to read the
// permission bits. CmdRunner runs stat(1) for owner/group checks; it is
// presumably a seam so tests can fake the command (NewPermsChecker uses ExecCmd).
type PermsChecker struct {
	Fs        afero.Fs
	CmdRunner func(string, ...string) ([]byte, error)
}

// NewPermsChecker returns a PermsChecker for fs that runs real commands via
// ExecCmd.
func NewPermsChecker(fs afero.Fs) *PermsChecker {
	return &PermsChecker{Fs: fs, CmdRunner: ExecCmd}
}

// CheckPerm checks whether the file at the given path has the desired permissions.
// The argument requirePerm is a list, so the caller can specify several
// permissions of which only one needs to match the file's permission bits.
// If requiredOwner or requiredGroup is not empty, the function also checks,
// via `stat -c "%U %G"`, that the owner and group of the file match them, and
// fails if they do not.
func (u *PermsChecker) CheckPerm(path string, requirePerm []fs.FileMode, requiredOwner string, requiredGroup string) error {
	fileInfo, err := u.Fs.Stat(path)
	if err != nil {
		return fmt.Errorf("failed to describe the file at path: %w", err)
	}
	mode := fileInfo.Mode()

	// if the requiredOwner or requiredGroup are specified then run stat and check if they match
	if requiredOwner != "" || requiredGroup != "" {
		statOutput, err := u.CmdRunner("stat", "-c", "%U %G", path)
		if err != nil {
			return fmt.Errorf("failed to run stat: %w", err)
		}
		statOutputSplit := strings.Split(strings.TrimSpace(string(statOutput)), " ")
		// Validate the field count before indexing, so unexpected stat output
		// (e.g. empty) returns an error instead of an index-out-of-range panic.
		if len(statOutputSplit) != 2 {
			return fmt.Errorf("expected stat command to return 2 values got %d", len(statOutputSplit))
		}
		statOwner, statGroup := statOutputSplit[0], statOutputSplit[1]
		if requiredOwner != "" && requiredOwner != statOwner {
			return fmt.Errorf("expected owner (%s), got (%s)", requiredOwner, statOwner)
		}
		if requiredGroup != "" && requiredGroup != statGroup {
			return fmt.Errorf("expected group (%s), got (%s)", requiredGroup, statGroup)
		}
	}

	// Every accepted mode is collected (not just the first match) so the
	// error message can list all of them.
	permMatch := false
	requiredPermString := []string{}
	for _, p := range requirePerm {
		requiredPermString = append(requiredPermString, fmt.Sprintf("%o", p.Perm()))
		if mode.Perm() == p {
			permMatch = true
		}
	}
	if !permMatch {
		return fmt.Errorf("expected one of the following permissions [%s], got (%o)", strings.Join(requiredPermString, ", "), mode.Perm())
	}
	return nil
}

// ExecCmd runs the named command with arg and returns its combined stdout and
// stderr output together with any execution error.
func ExecCmd(name string, arg ...string) ([]byte, error) {
	cmd := exec.Command(name, arg...)
	return cmd.CombinedOutput()
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package files
import (
        "log"
        "strings"
        "github.com/kballard/go-shellquote"
)
// Table holds rows of shell-style words, such as the lines of an opkssh policy
// file.
type Table struct {
	rows [][]string
}

// NewTable parses content line by line into a Table. Each line is first passed
// through CleanRow, which drops '#' comments and surrounding whitespace, and
// blank results are skipped. The remainder is split into columns using shell
// quoting rules. Lines that fail to parse are logged and skipped.
func NewTable(content []byte) *Table {
	parsed := [][]string{}
	for _, line := range strings.Split(string(content), "\n") {
		cleaned := CleanRow(line)
		if cleaned == "" {
			continue
		}
		columns, err := shellquote.Split(cleaned)
		if err != nil {
			log.Printf("Unable to parse: %s. (%s), skipping...\n", cleaned, err)
			continue
		}
		parsed = append(parsed, columns)
	}
	return &Table{rows: parsed}
}
// CleanRow strips a comment (everything from the first '#', even inside
// quotes) and surrounding whitespace from row. An empty result means the row
// carries no data.
func CleanRow(row string) string {
	data, _, _ := strings.Cut(row, "#")
	return strings.TrimSpace(data)
}
// AddRow appends row, as given, as the new last row of t.
func (t *Table) AddRow(row ...string) {
	t.rows = append(t.rows, row)
}

// ToString serialises t with one line per row. Columns are quoted and joined
// with shellquote.Join, and every row, including the last, ends in "\n".
func (t Table) ToString() string {
	var out strings.Builder
	for i := range t.rows {
		out.WriteString(shellquote.Join(t.rows[i]...))
		out.WriteString("\n")
	}
	return out.String()
}

// ToBytes returns ToString's result as a byte slice.
func (t Table) ToBytes() []byte {
	serialized := t.ToString()
	return []byte(serialized)
}

// GetRows returns the table's rows. The returned slice shares storage with t.
func (t Table) GetRows() [][]string {
	rows := t.rows
	return rows
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package policy
import (
        "errors"
        "fmt"
        "log"
        "os"
        "os/exec"
        "strings"
)
// Compile-time assertion that *MultiPolicyLoader implements Loader. Setting
// LoaderScript to ReadWithSudoScript also checks that ReadWithSudoScript is
// assignable to OptionalLoader.
var _ Loader = &MultiPolicyLoader{
        LoaderScript: ReadWithSudoScript,
}
// FileSource implements policy.Source by returning a string that is expected to
// be a filepath. MultiPolicyLoader.Load may also store several paths joined
// with ", ".
type FileSource string
// Source returns the stored path string unchanged.
func (s FileSource) Source() string {
        return string(s)
}
// NewMultiPolicyLoader returns a MultiPolicyLoader for username that uses the
// default home and system policy loaders. loader (for example
// ReadWithSudoScript) is handed to HomePolicyLoader.LoadHomePolicy as an
// additional way to read the home policy.
func NewMultiPolicyLoader(username string, loader OptionalLoader) *MultiPolicyLoader {
	home := NewHomePolicyLoader()
	system := NewSystemPolicyLoader()
	return &MultiPolicyLoader{
		Username:           username,
		LoaderScript:       loader,
		SystemPolicyLoader: system,
		HomePolicyLoader:   home,
	}
}
// MultiPolicyLoader implements policy.Loader by reading both the system default
// policy (root policy) and user policy (~/.opk/auth_id where ~ maps to
// Username's home directory)
type MultiPolicyLoader struct {
	HomePolicyLoader   *HomePolicyLoader
	SystemPolicyLoader *SystemPolicyLoader
	LoaderScript       OptionalLoader
	Username           string
}

// Load reads the system policy and Username's home policy and returns the
// concatenation of their user entries, system entries first. A failure to read
// either policy is logged and tolerated. An error (joining both load errors) is
// returned only when neither loader produced a policy. The returned Source
// lists the path(s) that contributed entries, separated by ", ".
func (l *MultiPolicyLoader) Load() (*Policy, Source, error) {
	policy := new(Policy)

	// Try to load the root policy
	rootPolicy, _, rootPolicyErr := l.SystemPolicyLoader.LoadSystemPolicy()
	if rootPolicyErr != nil {
		log.Println("warning: failed to load system default policy:", rootPolicyErr)
	}

	// Try to load the user policy
	userPolicy, userPolicyFilePath, userPolicyErr := l.HomePolicyLoader.LoadHomePolicy(l.Username, true, l.LoaderScript)
	if userPolicyErr != nil {
		log.Println("warning: failed to load user policy:", userPolicyErr)
	} else if userPolicy != nil && len(userPolicy.Users) == 0 {
		// Loaded fine but contributed no entries. The nil guard matches the
		// nil checks below, so a nil policy cannot cause a dereference panic.
		log.Printf("warning: user policy %s has no valid user entries; an entry is considered valid if it gives %s access.", userPolicyFilePath, l.Username)
	}

	// Failed to read both policies. Return multi-error
	if rootPolicy == nil && userPolicy == nil {
		return nil, EmptySource{}, errors.Join(rootPolicyErr, userPolicyErr)
	}

	// TODO-Yuval: Optimize by merging duplicate entries instead of blindly
	// appending
	readPaths := []string{}
	if rootPolicy != nil {
		policy.Users = append(policy.Users, rootPolicy.Users...)
		readPaths = append(readPaths, SystemDefaultPolicyPath)
	}
	if userPolicy != nil {
		policy.Users = append(policy.Users, userPolicy.Users...)
		readPaths = append(readPaths, userPolicyFilePath)
	}
	return policy, FileSource(strings.Join(readPaths, ", ")), nil
}
// ReadWithSudoScript provides an additional way to load the policy in the
// user's home directory (`~/.opk/auth_id`). This is needed when the
// AuthorizedKeysCommand user does not have privileges to traverse the user's
// home directory. Instead, we run a command (`sudo -n <opkssh> readhome
// <username>`) which uses special sudoers permissions to read the policy file.
//
// Doing this is more secure than simply giving opkssh sudoer access. If there
// were an RCE in opkssh that could be triggered by an SSH request via
// AuthorizedKeysCommand, the new opkssh process we use to perform the read
// would not be compromised. Thus, the compromised opkssh process could not
// assume full root privileges.
//
// The signature matches OptionalLoader; h is not used.
//
// It returns the command's combined stdout/stderr on success. On failure, the
// returned error wraps the underlying error (from os.Executable or from
// running the command). Callers can therefore use errors.As to inspect it,
// e.g. for an *exec.ExitError.
func ReadWithSudoScript(h *HomePolicyLoader, username string) ([]byte, error) {
        // opkssh readhome ensures the file is not a symlink and has the expected
        // permissions/ownership.
        // The default path is /usr/local/bin/opkssh
        opkBin, err := os.Executable()
        if err != nil {
                return nil, fmt.Errorf("error getting opkssh executable path: %w", err)
        }
        // -n: never prompt for a password; fail instead if sudo would need one.
        cmd := exec.Command("sudo", "-n", opkBin, "readhome", username)
        homePolicyFileBytes, err := cmd.CombinedOutput()
        if err != nil {
                // %w (not %v) keeps the underlying error inspectable by callers.
                return nil, fmt.Errorf("error reading %s home policy using command %v got output %v and err %w", username, cmd, string(homePolicyFileBytes), err)
        }
        return homePolicyFileBytes, nil
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package plugins
import (
        "bytes"
        "encoding/base64"
        "fmt"
        "io/fs"
        "os"
        "os/exec"
        "path/filepath"
        "strings"
        "github.com/kballard/go-shellquote"
        "github.com/openpubkey/openpubkey/pktoken"
        "github.com/openpubkey/opkssh/policy/files"
        "github.com/spf13/afero"
        "gopkg.in/yaml.v3"
)
// requiredPolicyPerms is the only file mode accepted for a policy plugin
// config file.
const requiredPolicyPerms fs.FileMode = 0640

var (
        // requiredPolicyDirPerms lists the modes accepted for the policy plugin
        // directory itself.
        requiredPolicyDirPerms = []fs.FileMode{0700, 0750, 0755}
        // requiredPolicyCmdPerms lists the modes accepted for the executable
        // named by a policy plugin command.
        requiredPolicyCmdPerms = []fs.FileMode{0555, 0755}
)
// PluginResult records the outcome of loading and (if loading succeeded)
// running a single policy plugin config file.
type PluginResult struct {
        // Path is the filepath of the plugin config file.
        Path         string
        // PluginConfig is the parsed config. It is left as the zero value if
        // loading the config failed.
        PluginConfig PluginConfig
        // Error is the error from loading the config or running its command,
        // if any.
        Error        error
        // CommandRun is the argv (config.Command split into words) that was run.
        CommandRun   []string
        // PolicyOutput is the whitespace-trimmed output of the policy command.
        PolicyOutput string
        // Allowed is true only when the command ran without error and
        // PolicyOutput is exactly "allow".
        Allowed      bool
}
// PluginResults holds one PluginResult per policy plugin config found.
type PluginResults []*PluginResult

// Errors returns the non-nil Error of each result, in order. It returns nil
// when no result has an error.
func (r PluginResults) Errors() []error {
        var collected []error
        for i := range r {
                if err := r[i].Error; err != nil {
                        collected = append(collected, err)
                }
        }
        return collected
}

// Allowed reports whether at least one plugin allowed access.
//
// As a double-entry bookkeeping check, a result with Allowed set must also
// have PolicyOutput equal to "allow". Allowed is only set to true when the
// plugin command printed exactly "allow", and PolicyOutput records what it
// printed. Any disagreement therefore means something went badly wrong, and
// Allowed panics rather than granting access.
func (r PluginResults) Allowed() bool {
        for _, res := range r {
                if !res.Allowed {
                        continue
                }
                if res.PolicyOutput != "allow" {
                        // Should never happen; see the doc comment above.
                        panic(fmt.Sprintf("Danger!!! Policy plugin command (%s) returned 'allow' but the plugin command did not approve. If you encounter this, report this as a vulnerability.", res.Path))
                }
                return true
        }
        return false
}
// CmdExecutor runs the program name with arguments arg and returns its
// combined stdout and stderr. It is a variable seam so unit tests can mock
// command execution.
type CmdExecutor func(name string, arg ...string) ([]byte, error)

// DefaultCmdExecutor is the CmdExecutor backed by os/exec. The child process
// inherits the current process environment.
func DefaultCmdExecutor(name string, arg ...string) ([]byte, error) {
        cmd := exec.Command(name, arg...)
        return cmd.CombinedOutput()
}
// PolicyPluginEnforcer loads policy plugin configs from a directory and runs
// the policy commands they specify.
type PolicyPluginEnforcer struct {
        Fs          afero.Fs
        cmdExecutor CmdExecutor // This lets us mock command exec in unit tests
        permChecker files.PermsChecker
}

// NewPolicyPluginEnforcer returns an enforcer that uses the real OS
// filesystem, DefaultCmdExecutor, and a files.PermsChecker (on the same
// filesystem) that runs commands via files.ExecCmd.
func NewPolicyPluginEnforcer() *PolicyPluginEnforcer {
        // Named osFs rather than fs to avoid shadowing the io/fs import.
        osFs := afero.NewOsFs()
        checker := files.PermsChecker{Fs: osFs, CmdRunner: files.ExecCmd}
        enforcer := &PolicyPluginEnforcer{
                Fs:          osFs,
                cmdExecutor: DefaultCmdExecutor,
                permChecker: checker,
        }
        return enforcer
}
// loadPlugins loads the policy plugin config files (names ending in ".yml")
// from dir.
//
// The directory itself must pass the permission check (one of
// requiredPolicyDirPerms, owner "root"). If it does not, or if dir cannot be
// listed, an error is returned and nothing is loaded. Any failure to stat a
// directory entry also aborts the whole load.
//
// Every non-directory ".yml" entry yields exactly one PluginResult. Problems
// specific to one config are recorded in that result's Error field instead of
// aborting, so one bad config does not prevent the others from being
// evaluated. Such problems are insecure permissions, an unreadable file,
// invalid YAML, or a missing `name`/`command` field. PluginConfig is only set
// for configs that loaded cleanly.
func (p *PolicyPluginEnforcer) loadPlugins(dir string) (PluginResults, error) {
        // Ensure the policy plugin directory can only be written by root.
        if err := p.permChecker.CheckPerm(dir, requiredPolicyDirPerms, "root", ""); err != nil {
                return nil, fmt.Errorf("policy plugin directory (%s) has insecure permissions: %w", dir, err)
        }
        filesFound, err := afero.ReadDir(p.Fs, dir)
        if err != nil {
                return nil, err
        }
        var pluginResults PluginResults
        for _, entry := range filesFound {
                path := filepath.Join(dir, entry.Name())
                info, err := p.Fs.Stat(path)
                if err != nil {
                        return nil, err
                }
                if info.IsDir() || !strings.HasSuffix(info.Name(), ".yml") {
                        continue
                }
                // Record the result before validating so failures are reported too.
                pluginResult := &PluginResult{Path: path}
                pluginResults = append(pluginResults, pluginResult)

                if err := p.permChecker.CheckPerm(path, []fs.FileMode{requiredPolicyPerms}, "root", ""); err != nil {
                        pluginResult.Error = fmt.Errorf("policy plugin config file (%s) has insecure permissions: %w", path, err)
                        continue
                }
                file, err := afero.ReadFile(p.Fs, path)
                if err != nil {
                        pluginResult.Error = fmt.Errorf("failed to read policy plugin config at (%s): %w", path, err)
                        continue
                }
                var cmd PluginConfig
                if err := yaml.Unmarshal(file, &cmd); err != nil {
                        pluginResult.Error = fmt.Errorf("failed to parse YAML in policy plugin config at (%s): %w", path, err)
                        continue
                }
                if cmd.Name == "" {
                        pluginResult.Error = fmt.Errorf("policy plugin config missing required field 'name' in policy plugin config at (%s)", path)
                        continue
                }
                if cmd.Command == "" {
                        pluginResult.Error = fmt.Errorf("policy plugin config missing required field 'command' in policy plugin config at (%s)", path)
                        continue
                }
                pluginResult.PluginConfig = cmd
        }
        return pluginResults, nil
}
// CheckPolicies loads the policy plugin configs in the directory dir and then
// runs the policy command specified in each policy plugin config to determine
// whether the user is allowed to assume access as the given principal. The
// command's environment is built from pkt, userInfoJson, principal, sshCert
// and keyType by PopulatePluginEnvVars. It returns one PluginResult for each
// plugin config found in the directory.
//
// Call PluginResults.Allowed() on the result to decide whether the user may
// assume access.
//
// CheckPolicies deliberately does not short circuit when a policy returns
// allow. This lets admins do a test rollout of a new policy plugin without
// disabling the old one until they are sure the new one works correctly.
func (p *PolicyPluginEnforcer) CheckPolicies(dir string, pkt *pktoken.PKToken, userInfoJson string, principal string, sshCert string, keyType string) (PluginResults, error) {
        envVars, err := PopulatePluginEnvVars(pkt, userInfoJson, principal, sshCert, keyType)
        if err != nil {
                return nil, err
        }
        results, err := p.checkPolicies(dir, envVars)
        return results, err
}
// checkPolicies loads every plugin config in dir. For each config that loaded
// without error, it runs the config's policy command with tokens as
// environment variables and records the command run and its trimmed output.
// A result is marked Allowed only when the command succeeded and its trimmed
// output is exactly "allow". A command failure is recorded in that result's
// Error. An error is returned only if loading the plugin directory fails.
func (p *PolicyPluginEnforcer) checkPolicies(dir string, tokens map[string]string) (PluginResults, error) {
        results, err := p.loadPlugins(dir)
        if err != nil {
                return nil, fmt.Errorf("failed to load policy commands: %w", err)
        }
        for _, res := range results {
                if res.Error != nil {
                        // The config itself failed to load; there is nothing to run.
                        continue
                }
                argv, rawOutput, execErr := p.executePolicyCommand(res.PluginConfig, tokens)
                trimmed := string(bytes.TrimSpace(rawOutput))
                res.PolicyOutput = trimmed
                res.CommandRun = argv
                if execErr != nil {
                        res.Error = fmt.Errorf("failed to run policy command %s got error (%w)", res.PluginConfig.Command, execErr)
                        continue
                }
                res.Error = nil
                res.Allowed = trimmed == "allow"
        }
        return results, nil
}
// executePolicyCommand runs the policy command from config with inputEnvVars
// exported as environment variables. It returns the argv that was run, the
// command's output, and any error.
//
// Side effects:
//   - inputEnvVars gains an OPKSSH_PLUGIN_CONFIG entry holding the
//     base64-encoded YAML serialization of config.
//   - The current process environment is modified. Every inherited
//     OPKSSH_PLUGIN_* variable is unset, then each entry of inputEnvVars is
//     set with os.Setenv. DefaultCmdExecutor passes this environment on to the
//     child process.
//
// NOTE(review): because this mutates process-wide environment state,
// concurrent calls would interfere with each other.
//
// Before running anything, the command string is split into words with shell
// quoting rules. An empty command is rejected. The executable (first word)
// must pass the permission check against requiredPolicyCmdPerms with owner
// "root".
func (p *PolicyPluginEnforcer) executePolicyCommand(config PluginConfig, inputEnvVars map[string]string) ([]string, []byte, error) {
        // Expose the plugin config itself to the command.
        configYaml, err := yaml.Marshal(config)
        if err != nil {
                return nil, nil, fmt.Errorf("failed to marshal config to YAML: %w", err)
        }
        inputEnvVars["OPKSSH_PLUGIN_CONFIG"] = base64.StdEncoding.EncodeToString(configYaml)

        // Ensure we don't use any environment variables as an input to the policy
        // plugin command that this process inherited. We only want to pass values
        // we set ourselves.
        for _, envVar := range os.Environ() {
                if !strings.HasPrefix(envVar, "OPKSSH_PLUGIN_") {
                        continue
                }
                name, _, _ := strings.Cut(envVar, "=")
                if err := os.Unsetenv(name); err != nil {
                        return nil, nil, fmt.Errorf("failed to unset environment variable %s: %w", name, err)
                }
        }
        for envK, envV := range inputEnvVars {
                if err := os.Setenv(envK, envV); err != nil {
                        return nil, nil, fmt.Errorf("failed to set environment variable %s: %w", envK, err)
                }
        }

        command, err := shellquote.Split(config.Command)
        if err != nil {
                return nil, nil, err
        }
        // A whitespace-only command splits into zero words; indexing command[0]
        // below would panic.
        if len(command) == 0 {
                return nil, nil, fmt.Errorf("policy plugin command (%q) is empty", config.Command)
        }
        executable := command[0]
        if err := p.permChecker.CheckPerm(executable, requiredPolicyCmdPerms, "root", ""); err != nil {
                if strings.Contains(err.Error(), "file does not exist") {
                        return nil, nil, err
                }
                return nil, nil, fmt.Errorf("policy plugin command (%s) has insecure permissions: %w", executable, err)
        }
        output, err := p.cmdExecutor(executable, command[1:]...)
        return command, output, err
}
// b64 returns the standard, padded base64 encoding (RFC 4648 alphabet with
// '+' and '/') of the bytes of s.
func b64(s string) string {
        raw := []byte(s)
        encoded := base64.StdEncoding.EncodeToString(raw)
        return encoded
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package plugins
import (
        "encoding/base64"
        "encoding/json"
        "fmt"
        "strings"
        "github.com/openpubkey/openpubkey/pktoken"
)
// PopulatePluginEnvVars builds the environment variables handed to policy
// plugin commands from the PK Token pkt and the SSH login context.
//
// principal, sshCert and keyType are passed through unchanged as
// OPKSSH_PLUGIN_U, OPKSSH_PLUGIN_K and OPKSSH_PLUGIN_T, and userInfoJson as
// OPKSSH_PLUGIN_USERINFO. Selected ID Token claims are exposed individually;
// a claim missing from the payload yields the empty string. A present
// `groups` claim is exposed as a properly escaped JSON array of strings in
// OPKSSH_PLUGIN_GROUPS.
//
// An error is returned if the PK Token cannot be compacted, its CIC values
// cannot be read, the user's public key cannot be marshalled, the payload
// cannot be parsed as JSON, or the groups claim cannot be encoded.
func PopulatePluginEnvVars(pkt *pktoken.PKToken, userInfoJson string, principal string, sshCert string, keyType string) (map[string]string, error) {
        pktCom, err := pkt.Compact()
        if err != nil {
                return nil, err
        }
        cicClaims, err := pkt.GetCicValues()
        if err != nil {
                return nil, err
        }
        upkJson, err := json.Marshal(cicClaims.PublicKey())
        if err != nil {
                return nil, err
        }
        upkB64 := base64.StdEncoding.EncodeToString(upkJson)

        // Pointer fields distinguish "claim absent" (nil -> "") from zero values.
        type Claims struct {
                Issuer        string    `json:"iss"`
                Sub           string    `json:"sub"`
                Email         string    `json:"email"`
                EmailVerified *bool     `json:"email_verified"`
                Aud           Audience  `json:"aud"`
                Exp           *int64    `json:"exp"`
                Nbf           *int64    `json:"nbf"`
                Iat           *int64    `json:"iat"`
                Jti           string    `json:"jti"`
                Groups        *[]string `json:"groups"`
        }
        var claims Claims
        if err := json.Unmarshal(pkt.Payload, &claims); err != nil {
                return nil, fmt.Errorf("error unmarshalling pk token payload: %w", err)
        }

        groupsStr := ""
        if claims.Groups != nil {
                // Encode with a real JSON encoder. Hand-joining with `","` would let
                // a group containing a quote inject extra array elements.
                groupsStr, err = encodePluginStringList(*claims.Groups)
                if err != nil {
                        return nil, fmt.Errorf("error encoding groups claim: %w", err)
                }
        }
        emailVerifiedStr := ""
        if claims.EmailVerified != nil {
                emailVerifiedStr = fmt.Sprintf("%t", *claims.EmailVerified)
        }
        formatOptionalInt := func(v *int64) string {
                if v == nil {
                        return ""
                }
                return fmt.Sprintf("%d", *v)
        }

        tokens := map[string]string{
                "OPKSSH_PLUGIN_U": principal,
                "OPKSSH_PLUGIN_K": sshCert,
                "OPKSSH_PLUGIN_T": keyType,

                "OPKSSH_PLUGIN_ISS":            claims.Issuer,
                "OPKSSH_PLUGIN_SUB":            claims.Sub,
                "OPKSSH_PLUGIN_EMAIL":          claims.Email,
                "OPKSSH_PLUGIN_EMAIL_VERIFIED": emailVerifiedStr,
                "OPKSSH_PLUGIN_AUD":            string(claims.Aud),
                "OPKSSH_PLUGIN_EXP":            formatOptionalInt(claims.Exp),
                "OPKSSH_PLUGIN_NBF":            formatOptionalInt(claims.Nbf),
                "OPKSSH_PLUGIN_IAT":            formatOptionalInt(claims.Iat),
                "OPKSSH_PLUGIN_JTI":            claims.Jti,
                "OPKSSH_PLUGIN_GROUPS":         groupsStr,

                "OPKSSH_PLUGIN_PAYLOAD":  base64.StdEncoding.EncodeToString(pkt.Payload), // base64-encoded ID Token payload
                "OPKSSH_PLUGIN_UPK":      upkB64,                                         // base64-encoded JWK of the user's public key
                "OPKSSH_PLUGIN_PKT":      string(pktCom),                                 // compact-encoded PK Token
                "OPKSSH_PLUGIN_IDT":      string(pkt.OpToken),                            // the ID Token as held in pkt.OpToken
                "OPKSSH_PLUGIN_USERINFO": userInfoJson,                                   // what the userinfo endpoint returned if an access token was supplied (by default this is the empty string)
        }
        return tokens, nil
}

// encodePluginStringList encodes values as a compact JSON array of strings
// (e.g. ["a","b"]). HTML escaping is disabled so values such as "a&b" appear
// verbatim. A nil slice is encoded as [] rather than null.
func encodePluginStringList(values []string) (string, error) {
        if values == nil {
                values = []string{}
        }
        var sb strings.Builder
        enc := json.NewEncoder(&sb)
        enc.SetEscapeHTML(false)
        if err := enc.Encode(values); err != nil {
                return "", err
        }
        // Encode terminates its output with a newline.
        return strings.TrimSuffix(sb.String(), "\n"), nil
}
// Audience holds the ID Token `aud` claim as a single string. A JSON string
// is stored verbatim. A JSON array of strings is stored as its compact,
// properly escaped JSON encoding (e.g. ["a","b"]).
type Audience string

// UnmarshalJSON implements json.Unmarshaler.
//
// A JSON null leaves a unchanged, following the encoding/json convention for
// Unmarshalers. An empty array is stored as []. A value that is neither a
// string nor an array of strings returns the error from decoding it as a
// string.
func (a *Audience) UnmarshalJSON(data []byte) error {
        var multi []string
        if err := json.Unmarshal(data, &multi); err == nil {
                if multi == nil {
                        // JSON null: no-op.
                        return nil
                }
                // Use a real encoder so quotes/backslashes in values are escaped;
                // hand-joining with `","` would produce ambiguous output.
                var sb strings.Builder
                enc := json.NewEncoder(&sb)
                enc.SetEscapeHTML(false)
                if err := enc.Encode(multi); err != nil {
                        return err
                }
                // Encode terminates its output with a newline.
                *a = Audience(strings.TrimSuffix(sb.String(), "\n"))
                return nil
        }
        var single string
        if err := json.Unmarshal(data, &single); err != nil {
                return err
        }
        *a = Audience(single)
        return nil
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package policy
import (
        "fmt"
        "log"
        "strings"
        "github.com/openpubkey/opkssh/policy/files"
)
// User is an opkssh policy user entry
type User struct {
        // IdentityAttribute is a string that is either structured or unstructured.
        // Structured: <IdentityProtocolMatching>:<Attribute>:<Value>
        // E.g. `oidc:groups:ssh-users`
        // Using the structured identifier allows the capability of constructing
        // complex user matchers.
        //
        // Unstructured:
        // This is the older format, which only works with OIDC ID Tokens, with
        // the claim being `email` or `sub`. The expected value is the user's
        // email or the user's subject identifier (`sub`). The subject identifier
        // is a unique identifier for the user at the OpenID Provider.
        IdentityAttribute string
        // Principals is the list of allowed principals. These are compared
        // against local usernames, e.g. by LoadHomePolicy.
        Principals []string
        // Sub        string
        // Issuer is the issuer (third column of a policy row) that the identity
        // must come from.
        Issuer string
}
// Policy represents an opkssh policy
type Policy struct {
        // Users is a list of all user entries in the policy
        Users []User
}
// FromTable decodes whitespace-delimited input into a Policy. Each row must
// have exactly three columns: principal, identity attribute and issuer. Each
// valid row becomes one User with a single principal.
//
// Rows with any other column count are skipped and recorded via
// files.ConfigProblems(), so that one malformed line does not break
// everyone's ability to log in. The recorded OffendingLineNumber is the row's
// index in table.GetRows(). path is used only for those problem reports.
//
// NOTE(review): Source is always "user policy file", even when this parses
// the system policy via LoadPolicyAtPath.
func FromTable(input []byte, path string) *Policy {
        table := files.NewTable(input)
        parsed := &Policy{}
        for rowIdx, cols := range table.GetRows() {
                if len(cols) == 3 {
                        parsed.Users = append(parsed.Users, User{
                                Principals:        []string{cols[0]},
                                IdentityAttribute: cols[1],
                                Issuer:            cols[2],
                        })
                        continue
                }
                files.ConfigProblems().RecordProblem(files.ConfigProblem{
                        Filepath:            path,
                        OffendingLine:       strings.Join(cols, " "),
                        OffendingLineNumber: rowIdx,
                        ErrorMessage:        fmt.Sprintf("wrong number of arguments (expected=3, got=%d)", len(cols)),
                        Source:              "user policy file",
                })
        }
        return parsed
}
// AddAllowedPrincipal grants principal to the identity userEmail issued by
// issuer. There are three outcomes:
//   - If an entry matching userEmail AND issuer already lists principal,
//     nothing changes.
//   - Otherwise, if any entry matches userEmail and issuer, principal is
//     appended to the first such entry.
//   - Otherwise, a new entry is appended with principal as its only allowed
//     principal.
//
// Each outcome is logged.
func (p *Policy) AddAllowedPrincipal(principal string, userEmail string, issuer string) {
        var target *User // first entry matching userEmail and issuer
        for idx := range p.Users {
                candidate := &p.Users[idx]
                if candidate.IdentityAttribute != userEmail || candidate.Issuer != issuer {
                        continue
                }
                if target == nil {
                        target = candidate
                }
                for _, existing := range candidate.Principals {
                        if existing == principal {
                                // Exact duplicate already present; nothing to add.
                                log.Printf("User with email %s already has access under the principal %s, skipping...\n", userEmail, principal)
                                return
                        }
                }
        }

        if target == nil {
                // No entry for this email/issuer pair yet.
                p.Users = append(p.Users, User{
                        IdentityAttribute: userEmail,
                        Principals:        []string{principal},
                        Issuer:            issuer,
                })
        } else {
                target.Principals = append(target.Principals, principal)
        }
        log.Printf("Successfully added user with email %s with principal %s to the policy file\n", userEmail, principal)
}
// ToTable encodes the policy as a whitespace-delimited table. There is one row
// (principal, identity attribute, issuer) per principal, so a user with
// several principals produces several rows. The returned error is currently
// always nil.
func (p *Policy) ToTable() ([]byte, error) {
        var table files.Table
        for _, u := range p.Users {
                for _, principal := range u.Principals {
                        table.AddRow(principal, u.IdentityAttribute, u.Issuer)
                }
        }
        return table.ToBytes(), nil
}
// Source describes where a fetched opkssh policy was retrieved from.
type Source interface {
        // Source returns a string describing the origin of an opkssh policy,
        // or "" when nothing is known about it.
        Source() string
}

// Compile-time check that EmptySource satisfies Source.
var _ Source = (*EmptySource)(nil)

// EmptySource is a Source that carries no information: its Source method
// always returns "".
type EmptySource struct{}

// Source implements Source by returning the empty string.
func (EmptySource) Source() string {
        return ""
}
// Loader declares the minimal interface to retrieve an opkssh policy from an
// arbitrary source.
type Loader interface {
        // Load fetches an opkssh policy and returns information describing its
        // source. If an error occurs, all return values are nil except the error
        // value.
        //
        // NOTE(review): MultiPolicyLoader.Load returns EmptySource{} (non-nil)
        // as the Source on error; callers should not rely on a nil Source.
        Load() (*Policy, Source, error)
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package policy
import (
        "fmt"
        "os/user"
        "path"
        "path/filepath"
        "github.com/openpubkey/opkssh/policy/files"
        "github.com/spf13/afero"
        "golang.org/x/exp/slices"
)
// SystemDefaultPolicyPath is the filepath of the system-wide opkssh policy,
// /etc/opk/auth_id, converted to the OS path separator.
var SystemDefaultPolicyPath = filepath.FromSlash("/etc/opk/auth_id")

// UserLookup is the minimal interface needed to resolve a username to an
// account on the current system.
type UserLookup interface {
        Lookup(username string) (*user.User, error)
}

// OsUserLookup implements UserLookup using the os/user package.
type OsUserLookup struct{}

// NewOsUserLookup returns a UserLookup backed by os/user.
func NewOsUserLookup() UserLookup {
        var lookup UserLookup = &OsUserLookup{}
        return lookup
}

// Lookup implements UserLookup by delegating to user.Lookup.
func (OsUserLookup) Lookup(username string) (*user.User, error) {
        return user.Lookup(username)
}
// PolicyLoader reads and writes opkssh policy files on an arbitrary filesystem
// through FileLoader. All methods that read policy from the filesystem fail
// and return an error immediately if the permission bits are invalid.
type PolicyLoader struct {
        FileLoader files.FileLoader
        UserLookup UserLookup
}

// CreateIfDoesNotExist delegates to FileLoader.CreateIfDoesNotExist for path.
func (l PolicyLoader) CreateIfDoesNotExist(path string) error {
        err := l.FileLoader.CreateIfDoesNotExist(path)
        return err
}

// LoadPolicyAtPath reads the policy file at path through FileLoader. The file
// must exist, be readable by the current process, and have the correct
// permission bits. The contents are parsed with FromTable. Any read error is
// returned unwrapped.
func (l *PolicyLoader) LoadPolicyAtPath(path string) (*Policy, error) {
        content, err := l.FileLoader.LoadFileAtPath(path)
        if err != nil {
                return nil, err
        }
        return FromTable(content, path), nil
}

// Dump encodes policy as a table (see Policy.ToTable) and writes it to path
// through FileLoader. Write errors are wrapped with the destination path.
func (l *PolicyLoader) Dump(policy *Policy, path string) error {
        encoded, err := policy.ToTable()
        if err != nil {
                return err
        }
        if err = l.FileLoader.Dump(encoded, path); err != nil {
                return fmt.Errorf("failed to write to policy file %s: %w", path, err)
        }
        return nil
}
// NewSystemPolicyLoader returns a SystemPolicyLoader that reads and writes the
// real OS filesystem, requires files.ModeSystemPerms on policy files, and
// resolves users through os/user.
func NewSystemPolicyLoader() *SystemPolicyLoader {
        inner := &PolicyLoader{
                FileLoader: files.FileLoader{Fs: afero.NewOsFs(), RequiredPerm: files.ModeSystemPerms},
                UserLookup: NewOsUserLookup(),
        }
        return &SystemPolicyLoader{PolicyLoader: inner}
}

// SystemPolicyLoader reads and writes the system-wide opkssh policy file on a
// filesystem. All methods that read policy from the filesystem fail and
// return an error immediately if the permission bits are invalid.
type SystemPolicyLoader struct {
        *PolicyLoader
}

// LoadSystemPolicy reads the opkssh policy at SystemDefaultPolicyPath and
// returns it with a FileSource naming that path. On failure (unreadable file
// or incorrect permission bits), it returns a nil policy, EmptySource{}, and
// a wrapped error.
func (s *SystemPolicyLoader) LoadSystemPolicy() (*Policy, Source, error) {
        p, err := s.LoadPolicyAtPath(SystemDefaultPolicyPath)
        if err != nil {
                return nil, EmptySource{}, fmt.Errorf("failed to read system default policy file %s: %w", SystemDefaultPolicyPath, err)
        }
        return p, FileSource(SystemDefaultPolicyPath), nil
}
// OptionalLoader is a fallback reader for a user's home policy file. It
// returns the raw file contents. LoadHomePolicy uses it when the file cannot
// be read directly.
type OptionalLoader func(h *HomePolicyLoader, username string) ([]byte, error)

// HomePolicyLoader reads and writes the opkssh policy file stored at
// `~/.opk/auth_id` on a filesystem. All methods that read policy from the
// filesystem fail and return an error immediately if the permission bits are
// invalid.
type HomePolicyLoader struct {
        *PolicyLoader
}

// NewHomePolicyLoader returns a HomePolicyLoader that reads and writes policy
// in users' home directories (e.g. `~/.opk/auth_id`) on the real OS
// filesystem. It requires files.ModeHomePerms on policy files and resolves
// users through os/user.
func NewHomePolicyLoader() *HomePolicyLoader {
        inner := &PolicyLoader{
                FileLoader: files.FileLoader{Fs: afero.NewOsFs(), RequiredPerm: files.ModeHomePerms},
                UserLookup: NewOsUserLookup(),
        }
        return &HomePolicyLoader{PolicyLoader: inner}
}
// LoadHomePolicy reads the user's opkssh policy at ~/.opk/auth_id (where ~
// maps to username's home directory). It returns the policy together with the
// filepath read. An error is returned in any of these cases:
//   - the file cannot be read or its permission bits are not correct;
//   - there is no user with username, or that user has no home directory;
//   - more than one optional loader is passed.
//
// If skipInvalidEntries is true, invalid entries are dropped. An entry is
// valid only if its principals include username. Each kept entry is rewritten
// to grant access to username alone.
//
// optLoader may hold at most one OptionalLoader. If one is given, it is used
// to read the policy file when the direct read fails, e.g. when this process
// lacks permission to read the user's home directory. If the fallback also
// fails, the returned error wraps both the fallback error and the original
// read error.
func (h *HomePolicyLoader) LoadHomePolicy(username string, skipInvalidEntries bool, optLoader ...OptionalLoader) (*Policy, string, error) {
        // Validate up front so misuse is reported even when the direct read
        // would have succeeded.
        if len(optLoader) > 1 {
                return nil, "", fmt.Errorf("only one optional loaders allowed, got %d", len(optLoader))
        }
        policyFilePath, err := h.UserPolicyPath(username)
        if err != nil {
                return nil, "", fmt.Errorf("error getting user policy path for user %s: %w", username, err)
        }
        policyBytes, userPolicyErr := h.FileLoader.LoadFileAtPath(policyFilePath)
        if userPolicyErr != nil {
                if len(optLoader) == 0 {
                        return nil, "", fmt.Errorf("failed to read user policy file %s: %w", policyFilePath, userPolicyErr)
                }
                // Fall back to the optional loader.
                policyBytes, err = optLoader[0](h, username)
                if err != nil {
                        // Keep both causes; the direct-read error explains why the
                        // fallback was needed.
                        return nil, "", fmt.Errorf("failed to read user policy file %s: %w (direct read failed with: %w)", policyFilePath, err, userPolicyErr)
                }
        }

        policy := FromTable(policyBytes, policyFilePath)
        if !skipInvalidEntries {
                // Just return what we read.
                return policy, policyFilePath, nil
        }

        // Build a valid user policy. Ignore entries that give access to a
        // principal other than the username whose policy file was read.
        validUserPolicy := new(Policy)
        for _, entry := range policy.Users {
                if slices.Contains(entry.Principals, username) {
                        // Build a clean entry that only gives access to username.
                        validUserPolicy.Users = append(validUserPolicy.Users, User{
                                IdentityAttribute: entry.IdentityAttribute,
                                Principals:        []string{username},
                                Issuer:            entry.Issuer,
                        })
                }
        }
        return validUserPolicy, policyFilePath, nil
}
// UserPolicyPath resolves the location of the per-user opkssh policy file,
// ~/.opk/auth_id, where ~ is the home directory that h.UserLookup reports for
// username.
//
// An error is returned when the lookup fails or when the account found has an
// empty home directory.
func (h *HomePolicyLoader) UserPolicyPath(username string) (string, error) {
        account, lookupErr := h.UserLookup.Lookup(username)
        if lookupErr != nil {
                return "", fmt.Errorf("failed to lookup username %s: %w", username, lookupErr)
        }
        home := account.HomeDir
        if home == "" {
                return "", fmt.Errorf("user %s does not have a home directory", username)
        }
        return path.Join(home, ".opk", "auth_id"), nil
}
		
		// Copyright 2025 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package policy
import (
        "fmt"
        "strings"
        "github.com/openpubkey/openpubkey/providers"
        "github.com/openpubkey/openpubkey/verifier"
        "github.com/openpubkey/opkssh/policy/files"
        "github.com/spf13/afero"
)
// ProvidersRow is one entry of the providers policy file. It holds an OpenID
// Provider issuer, the client ID used for that issuer, and the name of an
// expiration policy (see GetExpirationPolicy for the accepted names).
type ProvidersRow struct {
        // Issuer is the issuer URL; CreateVerifier picks the provider
        // implementation based on this value.
        Issuer           string
        ClientID         string
        // ExpirationPolicy is the raw policy name as written in the file,
        // e.g. "24h" or "oidc".
        ExpirationPolicy string
}
// GetExpirationPolicy translates the row's ExpirationPolicy name into the
// matching verifier.ExpirationPolicy.
//
// Accepted names are "12h", "24h", "48h", "1week", "oidc", "oidc_refreshed"
// and "never". Any other value yields the zero ExpirationPolicy and an error.
func (p ProvidersRow) GetExpirationPolicy() (verifier.ExpirationPolicy, error) {
        byName := map[string]verifier.ExpirationPolicy{
                "12h":            verifier.ExpirationPolicies.MAX_AGE_12HOURS,
                "24h":            verifier.ExpirationPolicies.MAX_AGE_24HOURS,
                "48h":            verifier.ExpirationPolicies.MAX_AGE_48HOURS,
                "1week":          verifier.ExpirationPolicies.MAX_AGE_1WEEK,
                "oidc":           verifier.ExpirationPolicies.OIDC,
                "oidc_refreshed": verifier.ExpirationPolicies.OIDC_REFRESHED,
                "never":          verifier.ExpirationPolicies.NEVER_EXPIRE,
        }
        if ep, known := byName[p.ExpirationPolicy]; known {
                return ep, nil
        }
        return verifier.ExpirationPolicy{}, fmt.Errorf("invalid expiration policy: %s", p.ExpirationPolicy)
}
// ToString renders the row in providers-file column order: issuer, client
// ID and expiration policy, separated by single spaces.
func (p ProvidersRow) ToString() string {
        columns := []string{p.Issuer, p.ClientID, p.ExpirationPolicy}
        return strings.Join(columns, " ")
}
// ProviderPolicy is the list of rows read from a providers policy file,
// kept in the order they were added.
type ProviderPolicy struct {
        rows []ProvidersRow
}

// AddRow appends row after any rows already in the policy.
func (p *ProviderPolicy) AddRow(row ProvidersRow) {
        grown := append(p.rows, row)
        p.rows = grown
}
// CreateVerifier builds a PK Token verifier that accepts tokens from every
// provider listed in the policy.
//
// The provider implementation for each row is chosen from its issuer:
//   - issuers starting with https://login.microsoftonline.com use the Azure
//     provider;
//   - https://gitlab.com uses the GitLab provider;
//   - https://token.actions.githubusercontent.com uses the GitHub provider
//     (the row's ClientID is not passed to it);
//   - every other issuer, including https://accounts.google.com and the
//     http://oidc.local / http://localhost: test issuers, uses the Google
//     provider configured with the row's issuer and client ID.
//
// Each provider is wrapped with its own row's expiration policy. An error is
// returned if any row has an invalid expiration policy, if there are no
// rows, or if the verifier cannot be constructed.
//
// NOTE(review): the verifier-wide WithExpirationPolicy option receives only
// the last row's expiration policy; confirm this is intended when rows use
// different policies.
func (p *ProviderPolicy) CreateVerifier() (*verifier.Verifier, error) {
        var lastExpiration verifier.ExpirationPolicy
        pvs := []verifier.ProviderVerifier{}
        for _, row := range p.rows {
                var op verifier.ProviderVerifier
                // TODO: We should handle this issuer matching in a more generic way
                switch {
                case strings.HasPrefix(row.Issuer, "https://login.microsoftonline.com"):
                        azureOpts := providers.GetDefaultAzureOpOptions()
                        azureOpts.Issuer = row.Issuer
                        azureOpts.ClientID = row.ClientID
                        op = providers.NewAzureOpWithOptions(azureOpts)
                case row.Issuer == "https://gitlab.com":
                        gitlabOpts := providers.GetDefaultGitlabOpOptions()
                        gitlabOpts.Issuer = row.Issuer
                        gitlabOpts.ClientID = row.ClientID
                        op = providers.NewGitlabOpWithOptions(gitlabOpts)
                case row.Issuer == "https://token.actions.githubusercontent.com":
                        op = providers.NewGithubOp(row.Issuer, "")
                default:
                        // Google itself, the oidc.local and localhost: test issuers,
                        // and any unrecognized issuer are all configured as
                        // Google-style OPs.
                        googleOpts := providers.GetDefaultGoogleOpOptions()
                        googleOpts.Issuer = row.Issuer
                        googleOpts.ClientID = row.ClientID
                        op = providers.NewGoogleOpWithOptions(googleOpts)
                }
                rowExpiration, expErr := row.GetExpirationPolicy()
                if expErr != nil {
                        return nil, expErr
                }
                lastExpiration = rowExpiration
                pvs = append(pvs, verifier.ProviderVerifierExpires{
                        ProviderVerifier: op,
                        Expiration:       rowExpiration,
                })
        }
        if len(pvs) == 0 {
                return nil, fmt.Errorf("no providers configured")
        }
        built, buildErr := verifier.NewFromMany(pvs, verifier.WithExpirationPolicy(lastExpiration))
        if buildErr != nil {
                return nil, buildErr
        }
        return built, nil
}
// ToString renders every row on its own line, each followed by "\n", in the
// order the rows were added. A policy with no rows renders as "".
func (p ProviderPolicy) ToString() string {
        var out strings.Builder
        for i := range p.rows {
                out.WriteString(p.rows[i].ToString())
                out.WriteByte('\n')
        }
        return out.String()
}
// ProvidersFileLoader loads providers policy files through the embedded
// files.FileLoader (LoadProviderPolicy uses its LoadFileAtPath).
//
// NOTE(review): Path is not read by any method visible here, since
// LoadProviderPolicy takes the path as an argument; confirm whether it is
// used elsewhere.
type ProvidersFileLoader struct {
        files.FileLoader
        Path string
}
// NewProviderFileLoader returns a ProvidersFileLoader that reads from the
// operating-system filesystem (afero.NewOsFs) with RequiredPerm set to
// files.ModeSystemPerms. Path is left empty.
func NewProviderFileLoader() *ProvidersFileLoader {
        loader := new(ProvidersFileLoader)
        loader.Fs = afero.NewOsFs()
        loader.RequiredPerm = files.ModeSystemPerms
        return loader
}
// LoadProviderPolicy reads the providers policy file at path and parses it
// with FromTable. An error from reading the file is returned unchanged.
// Rows that FromTable rejects do not produce an error here.
func (o *ProvidersFileLoader) LoadProviderPolicy(path string) (*ProviderPolicy, error) {
        raw, readErr := o.LoadFileAtPath(path)
        if readErr != nil {
                return nil, readErr
        }
        return o.FromTable(raw, path), nil
}
// ToTable encodes the rows of opPolicies into a files.Table, producing one
// table row per policy row. Columns are issuer, client ID and expiration
// policy, which is the same column order FromTable reads.
func (o ProvidersFileLoader) ToTable(opPolicies ProviderPolicy) files.Table {
        table := files.Table{}
        for _, row := range opPolicies.rows {
                table.AddRow(row.Issuer, row.ClientID, row.ExpirationPolicy)
        }
        return table
}
// FromTable decodes whitespace-delimited input into a *ProviderPolicy.
// path is used only to label any configuration problems that are recorded.
//
// Each row must have exactly three columns: issuer, client ID and expiration
// policy. Some rows are invalid: they have the wrong number of columns, or
// their expiration policy is rejected by ProvidersRow.GetExpirationPolicy.
// Such rows are recorded through files.ConfigProblems() and skipped. One bad
// entry therefore does not make CreateVerifier fail for every provider.
//
// NOTE(review): OffendingLineNumber is the index into table.GetRows(). If
// GetRows drops comments or blank lines, this is not the physical line number
// in the file; confirm.
func (o *ProvidersFileLoader) FromTable(input []byte, path string) *ProviderPolicy {
        table := files.NewTable(input)
        policy := &ProviderPolicy{
                rows: []ProvidersRow{},
        }
        report := func(lineNumber int, row []string, msg string) {
                files.ConfigProblems().RecordProblem(files.ConfigProblem{
                        Filepath:            path,
                        OffendingLine:       strings.Join(row, " "),
                        OffendingLineNumber: lineNumber,
                        ErrorMessage:        msg,
                        Source:              "providers policy file",
                })
        }
        for i, row := range table.GetRows() {
                // An error should not break everyone's ability to log in, so record
                // the problem and skip the row instead of failing the whole file.
                if len(row) != 3 {
                        report(i, row, fmt.Sprintf("wrong number of arguments (expected=3, got=%d)", len(row)))
                        continue
                }
                policyRow := ProvidersRow{
                        Issuer:           row[0],
                        ClientID:         row[1],
                        ExpirationPolicy: row[2],
                }
                // Validate here so the offending line is reported. Otherwise the
                // bad value surfaces later in CreateVerifier, which rejects the
                // whole policy.
                if _, err := policyRow.GetExpirationPolicy(); err != nil {
                        report(i, row, err.Error())
                        continue
                }
                policy.AddRow(policyRow)
        }
        return policy
}
		
		// Copyright 2024 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sshcert
import (
        "context"
        "crypto/rand"
        "encoding/json"
        "fmt"
        "time"
        "github.com/lestrrat-go/jwx/v2/jwk"
        "github.com/openpubkey/openpubkey/pktoken"
        "github.com/openpubkey/openpubkey/verifier"
        "golang.org/x/crypto/ssh"
)
// SshCertSmuggler wraps an SSH user certificate that carries ("smuggles") an
// OpenPubkey PK Token, and optionally an access token, in the certificate's
// extensions ("openpubkey-pkt" and "openpubkey-act").
type SshCertSmuggler struct {
        SshCert *ssh.Certificate
}
// New builds an unsigned SSH user certificate that carries pkt.
//
// The certificate has the following contents:
//   - Key is the public key from the PK Token's CIC.
//   - KeyId is the payload's "email" claim, or "" if the claim is absent.
//   - ValidPrincipals is principals.
//   - ValidBefore is ssh.CertTimeInfinity.
//   - Extensions hold the usual permit-* flags plus the compact PK Token
//     under "openpubkey-pkt".
//
// If accessToken is non-nil it is also stored under "openpubkey-act". Errors
// from decoding the payload, extracting the key or compacting the token are
// returned unchanged.
func New(pkt *pktoken.PKToken, accessToken []byte, principals []string) (*SshCertSmuggler, error) {
        // TODO: assumes email exists in ID Token,
        // this will break for OPs like Azure that do not have email as a claim
        var idClaims struct {
                Email string `json:"email"`
        }
        if err := json.Unmarshal(pkt.Payload, &idClaims); err != nil {
                return nil, err
        }
        certKey, err := sshPubkeyFromPKT(pkt)
        if err != nil {
                return nil, err
        }
        compactPkt, err := pkt.Compact()
        if err != nil {
                return nil, err
        }
        permits := []string{
                "permit-X11-forwarding",
                "permit-agent-forwarding",
                "permit-port-forwarding",
                "permit-pty",
                "permit-user-rc",
        }
        exts := make(map[string]string, len(permits)+2)
        for _, permit := range permits {
                exts[permit] = ""
        }
        exts["openpubkey-pkt"] = string(compactPkt)
        if accessToken != nil {
                exts["openpubkey-act"] = string(accessToken)
        }
        cert := &ssh.Certificate{
                Key:             certKey,
                CertType:        ssh.UserCert,
                KeyId:           idClaims.Email,
                ValidPrincipals: principals,
                ValidBefore:     ssh.CertTimeInfinity,
                Permissions: ssh.Permissions{
                        Extensions: exts,
                },
        }
        return &SshCertSmuggler{SshCert: cert}, nil
}
// NewFromAuthorizedKey parses an authorized_keys-style entry and wraps the
// result in an SshCertSmuggler. The entry is formed by joining certType and
// certB64 with a single space. It fails if the entry cannot be parsed, or if
// the parsed key is a plain public key rather than an SSH certificate.
func NewFromAuthorizedKey(certType string, certB64 string) (*SshCertSmuggler, error) {
        entry := []byte(certType + " " + certB64)
        parsed, _, _, _, err := ssh.ParseAuthorizedKey(entry)
        if err != nil {
                return nil, err
        }
        cert, isCert := parsed.(*ssh.Certificate)
        if !isCert {
                return nil, fmt.Errorf("parsed SSH authorized_key is not an SSH certificate")
        }
        return &SshCertSmuggler{SshCert: cert}, nil
}
// SignCert signs the wrapped certificate in place with signerMas, using
// crypto/rand as the randomness source. It returns the same certificate
// pointer held in s.SshCert, now carrying a signature.
func (s *SshCertSmuggler) SignCert(signerMas ssh.MultiAlgorithmSigner) (*ssh.Certificate, error) {
        cert := s.SshCert
        signErr := cert.SignCert(rand.Reader, signerMas)
        if signErr != nil {
                return nil, signErr
        }
        return cert, nil
}
// VerifyCaSig checks that the wrapped certificate's signature was made by
// caPubkey over the certificate's to-be-signed bytes.
//
// It returns an error if the certificate carries no signature (for example,
// one built with New that has not yet been passed to SignCert). Otherwise it
// returns whatever caPubkey.Verify reports.
//
// This only checks the cryptographic signature against the supplied key. It
// does not compare caPubkey with s.SshCert.SignatureKey.
func (s *SshCertSmuggler) VerifyCaSig(caPubkey ssh.PublicKey) error {
        // Guard so Verify is never handed a nil signature for an unsigned cert.
        if s.SshCert.Signature == nil {
                return fmt.Errorf("ssh certificate has no signature to verify")
        }
        // Rebuild the bytes the CA signed. Marshal a copy with the signature
        // cleared, then drop the trailing signature length bytes
        // (see crypto.ssh.certs.go).
        unsigned := *s.SshCert
        unsigned.Signature = nil
        signedBytes := unsigned.Marshal()
        signedBytes = signedBytes[:len(signedBytes)-4]
        return caPubkey.Verify(signedBytes, s.SshCert.Signature)
}
// GetPKToken reads the compact PK Token from the certificate's
// "openpubkey-pkt" extension and deserializes it. It returns an error if the
// extension is absent or the token cannot be parsed.
func (s *SshCertSmuggler) GetPKToken() (*pktoken.PKToken, error) {
        compact, present := s.SshCert.Extensions["openpubkey-pkt"]
        if !present {
                return nil, fmt.Errorf("cert is missing required openpubkey-pkt extension")
        }
        token, parseErr := pktoken.NewFromCompact([]byte(compact))
        if parseErr != nil {
                return nil, fmt.Errorf("openpubkey-pkt extension in cert failed deserialization: %w", parseErr)
        }
        return token, nil
}
// GetAccessToken returns the certificate's "openpubkey-act" extension value.
// It returns "" when the extension is not present. Generally we don't expect
// this to be set.
func (s *SshCertSmuggler) GetAccessToken() string {
        // Indexing a map with a missing key (or a nil map) yields "".
        return s.SshCert.Extensions["openpubkey-act"]
}
// VerifySshPktCert verifies the PK Token carried by the certificate and
// checks that the token is bound to the certificate's public key.
//
// The steps are:
//  1. Extract the PK Token from the "openpubkey-pkt" extension.
//  2. Verify it with pktVerifier under a 30 second timeout derived from ctx.
//  3. Require the public key in the token's CIC ("upk") to equal the
//     certificate's public key.
//
// On success the verified PK Token is returned. The certificate's CA
// signature is not checked here.
func (s *SshCertSmuggler) VerifySshPktCert(ctx context.Context, pktVerifier verifier.Verifier) (*pktoken.PKToken, error) {
        // GetPKToken already reports a descriptive error (missing extension or
        // failed deserialization), so return it as-is rather than wrapping it
        // again with the same prefix.
        pkt, err := s.GetPKToken()
        if err != nil {
                return nil, err
        }
        ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
        defer cancel()
        if err := pktVerifier.VerifyPKToken(ctxWithTimeout, pkt); err != nil {
                return nil, err
        }
        cic, err := pkt.GetCicValues()
        if err != nil {
                return nil, err
        }
        upk := cic.PublicKey()
        // The key comes from a presented certificate. Use a checked assertion
        // so an unexpected key type yields an error instead of a panic.
        certKey, ok := s.SshCert.Key.(ssh.CryptoPublicKey)
        if !ok {
                return nil, fmt.Errorf("certificate public key of type %T does not expose a crypto public key", s.SshCert.Key)
        }
        jwkCertKey, err := jwk.FromRaw(certKey.CryptoPublicKey())
        if err != nil {
                return nil, err
        }
        if !jwk.Equal(jwkCertKey, upk) {
                return nil, fmt.Errorf("public key 'upk' in PK Token does not match public key in certificate")
        }
        return pkt, nil
}
// sshPubkeyFromPKT converts the user public key ("upk") in the PK Token's
// CIC into an ssh.PublicKey. Errors from reading the CIC, extracting the raw
// key or building the SSH key are returned unchanged.
func sshPubkeyFromPKT(pkt *pktoken.PKToken) (ssh.PublicKey, error) {
        cic, err := pkt.GetCicValues()
        if err != nil {
                return nil, err
        }
        var raw any
        if err = cic.PublicKey().Raw(&raw); err != nil {
                return nil, err
        }
        return ssh.NewPublicKey(raw)
}