Skip to content

Instantly share code, notes, and snippets.

@nobiit
Created November 24, 2024 05:05
Show Gist options
  • Select an option

  • Save nobiit/caceebd95a45b8312da9ba97bc3bcf79 to your computer and use it in GitHub Desktop.

Select an option

Save nobiit/caceebd95a45b8312da9ba97bc3bcf79 to your computer and use it in GitHub Desktop.
package keyring
import (
"errors"
"fmt"
"go.nobidev.com/ssh-agent/pkg/api"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// Compile-time check that *withSession satisfies api.KeyringSession.
var _ api.KeyringSession = (*withSession)(nil)
// withSession decorates an api.KeyringContext with state captured from the
// OpenSSH session-bind agent extension (see its Extension method).
type withSession struct {
// The wrapped context; methods not overridden here are promoted from it.
api.KeyringContext
// hostKey is the server host key taken from a session-bind request;
// nil until such a request has been processed.
hostKey ssh.PublicKey
// sessionID is the raw session identifier from a session-bind request.
// It is assigned only after the request's signature has verified.
sessionID []byte
}
// WithSession wraps k in a *withSession that embeds it.
//
// The wrapper forwards every call to k. For extension requests that k reports
// as unsupported, it tries to handle them itself and records session-bind
// state, which is exposed through the api.KeyringSession methods.
func WithSession(k api.KeyringContext) api.KeyringContext {
	wrapped := new(withSession)
	wrapped.KeyringContext = k
	return wrapped
}
// getHosts returns the hosts that the wrapped context's api.KeyringTrust.Trust
// reports for key. It returns nil when the wrapped context does not implement
// api.KeyringTrust.
func (s *withSession) getHosts(key ssh.PublicKey) []string {
	trust, implemented := s.KeyringContext.(api.KeyringTrust)
	if !implemented {
		return nil
	}
	// The error from Trust is discarded, so callers cannot tell a failed
	// lookup apart from "no hosts". NOTE(review): confirm this best-effort
	// behavior is intended.
	hosts, _ := trust.Trust(key)
	return hosts
}
// sessionBind handles the payload of the "session-bind@openssh.com" agent
// extension.
//
// contents is decoded with ssh.Unmarshal as four fields, in order:
//   - the server host key blob;
//   - the session identifier;
//   - the server's signature over that identifier;
//   - an is-forwarding flag, which is decoded but currently unused.
//
// The signature is verified with the supplied host key. Only if verification
// succeeds are the host key and session identifier recorded on s, so a
// rejected request leaves previously recorded state untouched.
//
// It returns (nil, nil) on success. It returns a non-nil error if the payload
// cannot be decoded, the host key cannot be parsed, the signature is
// malformed, or the signature does not verify.
func (s *withSession) sessionBind(contents []byte) ([]byte, error) {
	var req struct {
		HostKey      []byte
		Session      []byte
		Signature    []byte
		IsForwarding bool
	}
	if err := ssh.Unmarshal(contents, &req); err != nil {
		return nil, fmt.Errorf("session-bind: decoding request: %w", err)
	}
	// Parse into a local rather than s.hostKey. Otherwise an unverified key
	// would stay recorded (and be reported by Hosts) when verification fails.
	hostKey, err := ssh.ParsePublicKey(req.HostKey)
	if err != nil {
		return nil, fmt.Errorf("session-bind: parsing host key: %w", err)
	}
	signature, signRest, ok := ssh.ParseSignatureBody(req.Signature)
	if !ok || len(signRest) > 0 {
		return nil, errors.New("invalid signature")
	}
	if err := hostKey.Verify(req.Session, signature); err != nil {
		return nil, fmt.Errorf("session-bind: verifying host key signature: %w", err)
	}
	// Commit both fields together, only after the signature checks out.
	s.hostKey = hostKey
	// ssh.Unmarshal decodes []byte fields as sub-slices of contents. Copy
	// the ID so the stored value does not alias a buffer owned by the caller.
	s.sessionID = append([]byte(nil), req.Session...)
	return nil, nil
}
// Extension handles SSH agent extension requests for this keyring.
//
// The request is first passed to the wrapped KeyringContext. Its result is
// returned unchanged unless the error is agent.ErrExtensionUnsupported; only
// then does this wrapper handle the request itself.
//
// The only extension handled here is "session-bind@openssh.com" (see
// sessionBind). Every other type yields agent.ErrExtensionUnsupported, returned
// unwrapped so that a direct == comparison by the agent server still
// recognises it.
func (s *withSession) Extension(extensionType string, contents []byte) ([]byte, error) {
	// Extension name sent by OpenSSH clients; it must match byte-for-byte.
	const sessionBindExtension = "session-bind@openssh.com"

	r, err := s.KeyringContext.Extension(extensionType, contents)
	if !errors.Is(err, agent.ErrExtensionUnsupported) {
		return r, err
	}
	switch extensionType {
	case sessionBindExtension:
		return s.sessionBind(contents)
	default:
		return nil, agent.ErrExtensionUnsupported
	}
}
// SessionID returns the recorded session identifier as lowercase hex with no
// separators. It returns "" while no identifier has been recorded.
func (s *withSession) SessionID() string {
	if len(s.sessionID) == 0 {
		// %x of an empty slice is "", so this shortcut preserves output.
		return ""
	}
	return fmt.Sprintf("%x", s.sessionID)
}
// Hosts returns the hosts that the wrapped context's api.KeyringTrust reports
// for the host key recorded from session-bind. It returns nil when no host
// key has been recorded, or when the wrapped context does not implement
// api.KeyringTrust.
func (s *withSession) Hosts() []string {
	if s.hostKey == nil {
		return nil
	}
	return s.getHosts(s.hostKey)
}
// ProcessName returns the command line of the client process attached to
// this keyring context, as reported by Process.Cmdline.
//
// It returns "" when no *ClientInfo or no process is attached. It also returns
// "" when the command line cannot be read; in that case a warning is logged.
func (s *withSession) ProcessName() string {
	c := GetContextKey[*ClientInfo](s.KeyringContext, clientInfoContextKey)
	if c == nil || c.Process == nil {
		return ""
	}
	cmdline, err := c.Process.Cmdline()
	if err != nil {
		// Label the warning with the call that actually failed.
		log.Warn("Process.Cmdline: %v", err)
		return ""
	}
	return cmdline
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment