Skip to content

Commit

Permalink
add principals from metadata in ssh cert request
Browse files Browse the repository at this point in the history
Signed-off-by: Abhijeet V <[email protected]>
  • Loading branch information
abvaidya committed May 19, 2023
1 parent 00a9f01 commit 5d7cbd6
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 10 deletions.
17 changes: 13 additions & 4 deletions libs/go/sia/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func registerSvc(svc options.Service, ztsUrl, metaEndpoint string, opts *options
//if ssh support is enabled then we need to generate the csr
//it is also generated for the primary service only
hostname := getServiceHostname(opts, svc, false)
sshCertRequest, sshCsr, err := generateSshRequest(opts, svc.Name, hostname)
sshCertRequest, sshCsr, err := generateSshRequest(opts, svc.Name, hostname, metaEndpoint)
if err != nil {
return err
}
Expand Down Expand Up @@ -335,7 +335,7 @@ func refreshSvc(svc options.Service, ztsUrl, metaEndpoint string, opts *options.
//if ssh support is enabled then we need to generate the csr
//it is also generated for the primary service only
hostname := getServiceHostname(opts, svc, false)
sshCertRequest, sshCsr, err := generateSshRequest(opts, svc.Name, hostname)
sshCertRequest, sshCsr, err := generateSshRequest(opts, svc.Name, hostname, metaEndpoint)
if err != nil {
return err
}
Expand Down Expand Up @@ -419,15 +419,24 @@ func refreshSvc(svc options.Service, ztsUrl, metaEndpoint string, opts *options.
return nil
}

func generateSshRequest(opts *options.Options, primaryServiceName, hostname string) (*zts.SSHCertRequest, string, error) {
func generateSshRequest(opts *options.Options, primaryServiceName, hostname, metaEndpoint string) (*zts.SSHCertRequest, string, error) {
var err error
var sshCsr string
var sshCertRequest *zts.SSHCertRequest
if opts.Ssh && opts.Services[0].Name == primaryServiceName {
if opts.SshHostKeyType == hostkey.Rsa {
sshCsr, err = util.GenerateSSHHostCSR(opts.SshPubKeyFile, opts.Domain, primaryServiceName, opts.PrivateIp, opts.ZTSCloudDomains)
} else {
sshCertRequest, err = util.GenerateSSHHostRequest(opts.SshPubKeyFile, opts.Domain, primaryServiceName, hostname, opts.PrivateIp, opts.InstanceId, opts.SshPrincipals, opts.ZTSCloudDomains)
sshPrincipals := opts.SshPrincipals
// additional ssh host principals are added on best effort basis, hence error below is ignored.
additionalSshHostPrincipals, _ := opts.Provider.GetAdditionalSshHostPrincipals(metaEndpoint)
if additionalSshHostPrincipals != "" {
if sshPrincipals != "" {
sshPrincipals = sshPrincipals + "," + additionalSshHostPrincipals
}
sshPrincipals = additionalSshHostPrincipals
}
sshCertRequest, err = util.GenerateSSHHostRequest(opts.SshPubKeyFile, opts.Domain, primaryServiceName, hostname, opts.PrivateIp, opts.InstanceId, sshPrincipals, opts.ZTSCloudDomains)
}
}
return sshCertRequest, sshCsr, err
Expand Down
34 changes: 29 additions & 5 deletions libs/go/sia/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"crypto/x509/pkix"
"fmt"
"github.com/AthenZ/athenz/libs/go/sia/ssh/hostkey"
"k8s.io/utils/strings/slices"
"log"
"net"
"net/url"
Expand Down Expand Up @@ -115,6 +116,10 @@ func (tp TestProvider) GetAccessManagementProfileFromMeta(string) (string, error
return "testProf", nil
}

func (tp TestProvider) GetAdditionalSshHostPrincipals(string) (string, error) {
return "my-vm,my-instance-id", nil
}

func TestUpdateFileNew(test *testing.T) {
testInternalUpdateFileNew(test, true)
testInternalUpdateFileNew(test, false)
Expand Down Expand Up @@ -561,11 +566,15 @@ func TestServiceAlreadyRegistered(test *testing.T) {

func TestGenerateSshRequest(test *testing.T) {

tp := TestProvider{
Name: "athenz.aws.us-west-2",
}
opts := options.Options{
Ssh: false,
Ssh: false,
Provider: tp,
}
// ssh option false we should get success with nils and empty csr
sshReq, sshCsr, err := generateSshRequest(&opts, "backend", "hostname.athenz.io")
sshReq, sshCsr, err := generateSshRequest(&opts, "backend", "hostname.athenz.io", "")
assert.Nil(test, sshReq)
assert.Equal(test, "", sshCsr)
assert.Nil(test, err)
Expand All @@ -576,7 +585,7 @@ func TestGenerateSshRequest(test *testing.T) {
Name: "api",
},
}
sshReq, sshCsr, err = generateSshRequest(&opts, "backend", "hostname.athenz.io")
sshReq, sshCsr, err = generateSshRequest(&opts, "backend", "hostname.athenz.io", "")
assert.Nil(test, sshReq)
assert.Equal(test, "", sshCsr)
assert.Nil(test, err)
Expand All @@ -585,14 +594,29 @@ func TestGenerateSshRequest(test *testing.T) {
opts.Domain = "athenz"
opts.ZTSAWSDomains = []string{"athenz.io"}
opts.SshHostKeyType = hostkey.Rsa
sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io")
sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io", "")
assert.Nil(test, sshReq)
assert.NotEmpty(test, sshCsr)
assert.Nil(test, err)
// ssh enabled with primary service and key type is ecdsa - empty csr but not-nil cert request
opts.SshHostKeyType = hostkey.Ecdsa
sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io")
sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io", "")
assert.NotNil(test, sshReq)
assert.Equal(test, 3, len(sshReq.CertRequestData.Principals))
assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "my-vm"))
assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "my-instance-id"))
assert.Empty(test, sshCsr)
assert.Nil(test, err)
// ssh enabled with primary service and key type is ecdsa - empty csr but not-nil cert request, opts defines sshPrincipals
opts.SshHostKeyType = hostkey.Ecdsa
opts.SshPrincipals = "cname.athenz.io"
sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io", "")
assert.NotNil(test, sshReq)
assert.Equal(test, 4, len(sshReq.CertRequestData.Principals))
assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "hostname.athenz.io"))
assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "cname.athenz.io"))
assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "my-vm"))
assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "my-instance-id"))
assert.Empty(test, sshCsr)
assert.Nil(test, err)
}
Expand Down
4 changes: 4 additions & 0 deletions libs/go/sia/aws/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ func (tp TestProvider) GetAccessManagementProfileFromMeta(string) (string, error
return "testProf", nil
}

func (tp TestProvider) GetAdditionalSshHostPrincipals(string) (string, error) {
return "i-1234edt22", nil
}

func TestUpdateFileNew(test *testing.T) {
testInternalUpdateFileNew(test, true)
testInternalUpdateFileNew(test, false)
Expand Down
8 changes: 7 additions & 1 deletion libs/go/sia/gcp/meta/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,17 @@ func GetInstancePrivateIp(metaEndpoint string) (string, error) {
}
return string(instanceIdBytes), nil
}

func GetInstancePublicIp(metaEndpoint string) (string, error) {
pubIpBytes, err := GetData(metaEndpoint, "/computeMetadata/v1/instance/network-interfaces/0/access-configs/0/external-ip")
if err != nil {
return "", err
}
return string(pubIpBytes), nil
}
func GetInstanceName(metaEndpoint string) (string, error) {
nameBytes, err := GetData(metaEndpoint, "/computeMetadata/v1/instance/name")
if err != nil {
return "", err
}
return string(nameBytes), nil
}
18 changes: 18 additions & 0 deletions libs/go/sia/gcp/meta/meta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,21 @@ func TestGetInstancePublicIp(test *testing.T) {
test.Errorf("want instancePubIp=20.20.20.20 got instancePubIp=%s", instancePubIp)
}
}

func TestGetInstanceName(test *testing.T) {
// Mock the metadata endpoints
router := httptreemux.New()
router.GET("/computeMetadata/v1/instance/name", func(w http.ResponseWriter, r *http.Request, params map[string]string) {
log.Println("Called /computeMetadata/v1/instance/name")
io.WriteString(w, "my-vm")
})

metaServer := &testServer{}
metaServer.start(router)
defer metaServer.stop()

instanceName, _ := GetInstanceName(metaServer.httpUrl())
if instanceName != "my-vm" {
test.Errorf("want instanceName=my-vm got instanceName=%s", instanceName)
}
}
3 changes: 3 additions & 0 deletions libs/go/sia/host/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,7 @@ type Provider interface {

// GetAccessManagementProfileFromMeta gets the profile info from the respective cloud
GetAccessManagementProfileFromMeta(string) (string, error)

// GetAdditionalSshHostPrincipals returns additional provider specific principals to be added in ssh host cert
GetAdditionalSshHostPrincipals(string) (string, error)
}
4 changes: 4 additions & 0 deletions libs/go/sia/options/mockawsprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,7 @@ func (tp MockAWSProvider) GetAccountDomainServiceFromMeta(string) (string, strin
func (tp MockAWSProvider) GetAccessManagementProfileFromMeta(string) (string, error) {
return "testProf", nil
}

func (tp MockAWSProvider) GetAdditionalSshHostPrincipals(string) (string, error) {
return "i-1234edt22", nil
}
4 changes: 4 additions & 0 deletions libs/go/sia/options/mockgcpprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,7 @@ func (tp MockGCPProvider) GetAccountDomainServiceFromMeta(base string) (string,
func (tp MockGCPProvider) GetAccessManagementProfileFromMeta(base string) (string, error) {
return "testProf", nil
}

func (tp MockGCPProvider) GetAdditionalSshHostPrincipals(base string) (string, error) {
return "my-vm,compute.1234567890000,my-vm.c.my-gcp-project.internal", nil
}
4 changes: 4 additions & 0 deletions provider/aws/sia-ec2/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,7 @@ func (eks EC2Provider) GetAccountDomainServiceFromMeta(base string) (string, str
func (tp EC2Provider) GetAccessManagementProfileFromMeta(base string) (string, error) {
return "", fmt.Errorf("not implemented")
}

func (tp EC2Provider) GetAdditionalSshHostPrincipals(base string) (string, error) {
return "", nil
}
4 changes: 4 additions & 0 deletions provider/aws/sia-eks/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,7 @@ func (eks EKSProvider) GetAccountDomainServiceFromMeta(base string) (string, str
func (tp EKSProvider) GetAccessManagementProfileFromMeta(base string) (string, error) {
return "", fmt.Errorf("not implemented")
}

func (tp EKSProvider) GetAdditionalSshHostPrincipals(base string) (string, error) {
return "", nil
}
4 changes: 4 additions & 0 deletions provider/aws/sia-fargate/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,7 @@ func (eks FargateProvider) GetAccountDomainServiceFromMeta(base string) (string,
func (tp FargateProvider) GetAccessManagementProfileFromMeta(base string) (string, error) {
return "", fmt.Errorf("not implemented")
}

func (tp FargateProvider) GetAdditionalSshHostPrincipals(base string) (string, error) {
return "", nil
}
4 changes: 4 additions & 0 deletions provider/gcp/sia-gce/devel/metamock/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var (
sa = `[email protected]`
instanceId = `3692465022399257023`
accessProfile = `access-profile`
instanceName = `my-vm`
)

func StartMetaServer(EndPoint string) {
Expand All @@ -50,6 +51,9 @@ func StartMetaServer(EndPoint string) {
http.HandleFunc("/computeMetadata/v1/instance/attributes/accessProfile", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, accessProfile)
})
http.HandleFunc("/computeMetadata/v1/instance/name", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, instanceName)
})

log.Println("Starting GCE Meta Mock listening on: " + EndPoint)
err := http.ListenAndServe(EndPoint, nil)
Expand Down
16 changes: 16 additions & 0 deletions provider/gcp/sia-gce/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,19 @@ func (tp GCEProvider) GetAccessManagementProfileFromMeta(base string) (string, e
}
return profile, nil
}

func (tp GCEProvider) GetAdditionalSshHostPrincipals(base string) (string, error) {
instanceName, err := meta.GetInstanceName(base)
if err != nil {
return "", err
}
project, err := meta.GetProject(base)
if err != nil {
return fmt.Sprintf("%s", instanceName), nil
}
instanceId, err := meta.GetInstanceId(base)
if err != nil {
return fmt.Sprintf("%s,%s.c.%s.internal", instanceName, instanceName, project), nil
}
return fmt.Sprintf("%s,compute.%s,%s.c.%s.internal", instanceName, instanceId, instanceName, project), nil
}
4 changes: 4 additions & 0 deletions provider/gcp/sia-gke/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,7 @@ func (tp GKEProvider) GetAccessManagementProfileFromMeta(base string) (string, e
}
return profile, nil
}

func (tp GKEProvider) GetAdditionalSshHostPrincipals(base string) (string, error) {
return "", nil
}

0 comments on commit 5d7cbd6

Please sign in to comment.