Commit 0c4878f1 authored by Tongyao Si's avatar Tongyao Si
Browse files

feat(Azure): support specifying user assigned identity's clientID to authenticate

parent 97421cbd
Showing with 154 additions and 131 deletions
+154 -131
......@@ -134,7 +134,7 @@ func main() {
}
p, err = provider.NewAWSSDProvider(domainFilter, cfg.AWSZoneType, cfg.AWSAssumeRole, cfg.DryRun)
case "azure-dns", "azure":
p, err = provider.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneIDFilter, cfg.AzureResourceGroup, cfg.DryRun)
p, err = provider.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneIDFilter, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.DryRun)
case "azure-private-dns":
p, err = provider.NewAzurePrivateDNSProvider(domainFilter, zoneIDFilter, cfg.AzureResourceGroup, cfg.AzureSubscriptionID, cfg.DryRun)
case "vinyldns":
......
......@@ -37,97 +37,98 @@ var (
// Config is a project-wide configuration
type Config struct {
Master string
KubeConfig string
RequestTimeout time.Duration
IstioIngressGatewayServices []string
ContourLoadBalancerService string
Sources []string
Namespace string
AnnotationFilter string
FQDNTemplate string
CombineFQDNAndAnnotation bool
IgnoreHostnameAnnotation bool
Compatibility string
PublishInternal bool
PublishHostIP bool
ConnectorSourceServer string
Provider string
GoogleProject string
GoogleBatchChangeSize int
GoogleBatchChangeInterval time.Duration
DomainFilter []string
ExcludeDomains []string
ZoneIDFilter []string
AlibabaCloudConfigFile string
AlibabaCloudZoneType string
AWSZoneType string
AWSZoneTagFilter []string
AWSAssumeRole string
AWSBatchChangeSize int
AWSBatchChangeInterval time.Duration
AWSEvaluateTargetHealth bool
AWSAPIRetries int
AWSPreferCNAME bool
AzureConfigFile string
AzureResourceGroup string
AzureSubscriptionID string
CloudflareProxied bool
CloudflareZonesPerPage int
CoreDNSPrefix string
RcodezeroTXTEncrypt bool
InfobloxGridHost string
InfobloxWapiPort int
InfobloxWapiUsername string
InfobloxWapiPassword string `secure:"yes"`
InfobloxWapiVersion string
InfobloxSSLVerify bool
InfobloxView string
InfobloxMaxResults int
DynCustomerName string
DynUsername string
DynPassword string `secure:"yes"`
DynMinTTLSeconds int
OCIConfigFile string
InMemoryZones []string
PDNSServer string
PDNSAPIKey string `secure:"yes"`
PDNSTLSEnabled bool
TLSCA string
TLSClientCert string
TLSClientCertKey string
Policy string
Registry string
TXTOwnerID string
TXTPrefix string
Interval time.Duration
Once bool
DryRun bool
LogFormat string
MetricsAddress string
LogLevel string
TXTCacheInterval time.Duration
ExoscaleEndpoint string
ExoscaleAPIKey string `secure:"yes"`
ExoscaleAPISecret string `secure:"yes"`
CRDSourceAPIVersion string
CRDSourceKind string
ServiceTypeFilter []string
CFAPIEndpoint string
CFUsername string
CFPassword string
RFC2136Host string
RFC2136Port int
RFC2136Zone string
RFC2136Insecure bool
RFC2136TSIGKeyName string
RFC2136TSIGSecret string `secure:"yes"`
RFC2136TSIGSecretAlg string
RFC2136TAXFR bool
NS1Endpoint string
NS1IgnoreSSL bool
TransIPAccountName string
TransIPPrivateKeyFile string
Master string
KubeConfig string
RequestTimeout time.Duration
IstioIngressGatewayServices []string
ContourLoadBalancerService string
Sources []string
Namespace string
AnnotationFilter string
FQDNTemplate string
CombineFQDNAndAnnotation bool
IgnoreHostnameAnnotation bool
Compatibility string
PublishInternal bool
PublishHostIP bool
ConnectorSourceServer string
Provider string
GoogleProject string
GoogleBatchChangeSize int
GoogleBatchChangeInterval time.Duration
DomainFilter []string
ExcludeDomains []string
ZoneIDFilter []string
AlibabaCloudConfigFile string
AlibabaCloudZoneType string
AWSZoneType string
AWSZoneTagFilter []string
AWSAssumeRole string
AWSBatchChangeSize int
AWSBatchChangeInterval time.Duration
AWSEvaluateTargetHealth bool
AWSAPIRetries int
AWSPreferCNAME bool
AzureConfigFile string
AzureResourceGroup string
AzureSubscriptionID string
AzureUserAssignedIdentityClientID string
CloudflareProxied bool
CloudflareZonesPerPage int
CoreDNSPrefix string
RcodezeroTXTEncrypt bool
InfobloxGridHost string
InfobloxWapiPort int
InfobloxWapiUsername string
InfobloxWapiPassword string `secure:"yes"`
InfobloxWapiVersion string
InfobloxSSLVerify bool
InfobloxView string
InfobloxMaxResults int
DynCustomerName string
DynUsername string
DynPassword string `secure:"yes"`
DynMinTTLSeconds int
OCIConfigFile string
InMemoryZones []string
PDNSServer string
PDNSAPIKey string `secure:"yes"`
PDNSTLSEnabled bool
TLSCA string
TLSClientCert string
TLSClientCertKey string
Policy string
Registry string
TXTOwnerID string
TXTPrefix string
Interval time.Duration
Once bool
DryRun bool
LogFormat string
MetricsAddress string
LogLevel string
TXTCacheInterval time.Duration
ExoscaleEndpoint string
ExoscaleAPIKey string `secure:"yes"`
ExoscaleAPISecret string `secure:"yes"`
CRDSourceAPIVersion string
CRDSourceKind string
ServiceTypeFilter []string
CFAPIEndpoint string
CFUsername string
CFPassword string
RFC2136Host string
RFC2136Port int
RFC2136Zone string
RFC2136Insecure bool
RFC2136TSIGKeyName string
RFC2136TSIGSecret string `secure:"yes"`
RFC2136TSIGSecretAlg string
RFC2136TAXFR bool
NS1Endpoint string
NS1IgnoreSSL bool
TransIPAccountName string
TransIPPrivateKeyFile string
}
var defaultConfig = &Config{
......@@ -311,6 +312,7 @@ func (cfg *Config) ParseFlags(args []string) error {
app.Flag("azure-config-file", "When using the Azure provider, specify the Azure configuration file (required when --provider=azure").Default(defaultConfig.AzureConfigFile).StringVar(&cfg.AzureConfigFile)
app.Flag("azure-resource-group", "When using the Azure provider, override the Azure resource group to use (required when --provider=azure-private-dns)").Default(defaultConfig.AzureResourceGroup).StringVar(&cfg.AzureResourceGroup)
app.Flag("azure-subscription-id", "When using the Azure provider, specify the Azure configuration file (required when --provider=azure-private-dns)").Default(defaultConfig.AzureSubscriptionID).StringVar(&cfg.AzureSubscriptionID)
app.Flag("azure-user-assigned-identity-client-id", "When using the Azure provider, override the client id of user assigned identity in config file (optional)").Default("").StringVar(&cfg.AzureUserAssignedIdentityClientID)
app.Flag("cloudflare-proxied", "When using the Cloudflare provider, specify if the proxy mode must be enabled (default: disabled)").BoolVar(&cfg.CloudflareProxied)
app.Flag("cloudflare-zones-per-page", "When using the Cloudflare provider, specify how many zones per page listed, max. possible 50 (default: 50)").Default(strconv.Itoa(defaultConfig.CloudflareZonesPerPage)).IntVar(&cfg.CloudflareZonesPerPage)
app.Flag("coredns-prefix", "When using the CoreDNS provider, specify the prefix name").Default(defaultConfig.CoreDNSPrefix).StringVar(&cfg.CoreDNSPrefix)
......
......@@ -49,6 +49,7 @@ type config struct {
ClientID string `json:"aadClientId" yaml:"aadClientId"`
ClientSecret string `json:"aadClientSecret" yaml:"aadClientSecret"`
UseManagedIdentityExtension bool `json:"useManagedIdentityExtension" yaml:"useManagedIdentityExtension"`
UserAssignedIdentityID string `json:"userAssignedIdentityID" yaml:"userAssignedIdentityID"`
}
// ZonesClient is an interface of dns.ZoneClient that can be stubbed for testing.
......@@ -65,18 +66,19 @@ type RecordSetsClient interface {
// AzureProvider implements the DNS provider for Microsoft's Azure cloud platform.
type AzureProvider struct {
domainFilter DomainFilter
zoneIDFilter ZoneIDFilter
dryRun bool
resourceGroup string
zonesClient ZonesClient
recordSetsClient RecordSetsClient
domainFilter DomainFilter
zoneIDFilter ZoneIDFilter
dryRun bool
resourceGroup string
userAssignedIdentityClientID string
zonesClient ZonesClient
recordSetsClient RecordSetsClient
}
// NewAzureProvider creates a new Azure provider.
//
// Returns the provider or an error if a provider could not be created.
func NewAzureProvider(configFile string, domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, resourceGroup string, dryRun bool) (*AzureProvider, error) {
func NewAzureProvider(configFile string, domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, resourceGroup string, userAssignedIdentityClientID string, dryRun bool) (*AzureProvider, error) {
contents, err := ioutil.ReadFile(configFile)
if err != nil {
return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err)
......@@ -91,6 +93,10 @@ func NewAzureProvider(configFile string, domainFilter DomainFilter, zoneIDFilter
if resourceGroup != "" {
cfg.ResourceGroup = resourceGroup
}
// If userAssignedIdentityClientID is provided explicitly, override existing one in config file
if userAssignedIdentityClientID != "" {
cfg.UserAssignedIdentityID = userAssignedIdentityClientID
}
var environment azure.Environment
if cfg.Cloud == "" {
......@@ -113,34 +119,22 @@ func NewAzureProvider(configFile string, domainFilter DomainFilter, zoneIDFilter
recordSetsClient.Authorizer = autorest.NewBearerAuthorizer(token)
provider := &AzureProvider{
domainFilter: domainFilter,
zoneIDFilter: zoneIDFilter,
dryRun: dryRun,
resourceGroup: cfg.ResourceGroup,
zonesClient: zonesClient,
recordSetsClient: recordSetsClient,
domainFilter: domainFilter,
zoneIDFilter: zoneIDFilter,
dryRun: dryRun,
resourceGroup: cfg.ResourceGroup,
userAssignedIdentityClientID: cfg.UserAssignedIdentityID,
zonesClient: zonesClient,
recordSetsClient: recordSetsClient,
}
return provider, nil
}
// getAccessToken retrieves Azure API access token.
func getAccessToken(cfg config, environment azure.Environment) (*adal.ServicePrincipalToken, error) {
// Try to retrive token with MSI.
if cfg.UseManagedIdentityExtension {
log.Info("Using managed identity extension to retrieve access token for Azure API.")
msiEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
return nil, fmt.Errorf("failed to get the managed service identity endpoint: %v", err)
}
token, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, environment.ServiceManagementEndpoint)
if err != nil {
return nil, fmt.Errorf("failed to create the managed service identity token: %v", err)
}
return token, nil
}
// Try to retrieve token with service principal credentials.
// Try to use service principal first, some AKS clusters are in an intermediate state that `UseManagedIdentityExtension` is `true`
// and service principal exists. In this case, we still want to use service principal to authenticate.
if len(cfg.ClientID) > 0 && len(cfg.ClientSecret) > 0 {
log.Info("Using client_id+client_secret to retrieve access token for Azure API.")
oauthConfig, err := adal.NewOAuthConfig(environment.ActiveDirectoryEndpoint, cfg.TenantID)
......@@ -155,6 +149,31 @@ func getAccessToken(cfg config, environment azure.Environment) (*adal.ServicePri
return token, nil
}
// Try to retrive token with MSI.
if cfg.UseManagedIdentityExtension {
log.Info("Using managed identity extension to retrieve access token for Azure API.")
msiEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
return nil, fmt.Errorf("failed to get the managed service identity endpoint: %v", err)
}
if cfg.UserAssignedIdentityID != "" {
log.Infof("Resolving to user assigned identity, client id is %s.", cfg.UserAssignedIdentityID)
token, err := adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, environment.ServiceManagementEndpoint, cfg.UserAssignedIdentityID)
if err != nil {
return nil, fmt.Errorf("failed to create the managed service identity token: %v", err)
}
return token, nil
} else {
log.Info("Resolving to system assigned identity.")
token, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, environment.ServiceManagementEndpoint)
if err != nil {
return nil, fmt.Errorf("failed to create the managed service identity token: %v", err)
}
return token, nil
}
}
return nil, fmt.Errorf("no credentials provided for Azure API")
}
......
......@@ -205,7 +205,7 @@ func (client *mockRecordSetsClient) CreateOrUpdate(ctx context.Context, resource
}
// newMockedAzureProvider creates an AzureProvider comprising the mocked clients for zones and recordsets
func newMockedAzureProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, dryRun bool, resourceGroup string, zones *[]dns.Zone, recordSets *[]dns.RecordSet) (*AzureProvider, error) {
func newMockedAzureProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zones *[]dns.Zone, recordSets *[]dns.RecordSet) (*AzureProvider, error) {
// init zone-related parts of the mock-client
pageIterator := mockZoneListResultPageIterator{
results: []dns.ZoneListResult{
......@@ -235,17 +235,18 @@ func newMockedAzureProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFilter
mockRecordSetListIterator: &mockRecordSetListIterator,
}
return newAzureProvider(domainFilter, zoneIDFilter, dryRun, resourceGroup, &zonesClient, &recordSetsClient), nil
return newAzureProvider(domainFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, &zonesClient, &recordSetsClient), nil
}
func newAzureProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, dryRun bool, resourceGroup string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider {
func newAzureProvider(domainFilter DomainFilter, zoneIDFilter ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zonesClient ZonesClient, recordsClient RecordsClient) *AzureProvider {
return &AzureProvider{
domainFilter: domainFilter,
zoneIDFilter: zoneIDFilter,
dryRun: dryRun,
resourceGroup: resourceGroup,
zonesClient: zonesClient,
recordSetsClient: recordsClient,
domainFilter: domainFilter,
zoneIDFilter: zoneIDFilter,
dryRun: dryRun,
resourceGroup: resourceGroup,
userAssignedIdentityClientID: userAssignedIdentityClientID,
zonesClient: zonesClient,
recordSetsClient: recordsClient,
}
}
......@@ -254,7 +255,7 @@ func validateAzureEndpoints(t *testing.T, endpoints []*endpoint.Endpoint, expect
}
func TestAzureRecord(t *testing.T) {
provider, err := newMockedAzureProvider(NewDomainFilter([]string{"example.com"}), NewZoneIDFilter([]string{""}), true, "k8s",
provider, err := newMockedAzureProvider(NewDomainFilter([]string{"example.com"}), NewZoneIDFilter([]string{""}), true, "k8s", "",
&[]dns.Zone{
createMockZone("example.com", "/dnszones/example.com"),
},
......@@ -290,7 +291,7 @@ func TestAzureRecord(t *testing.T) {
}
func TestAzureMultiRecord(t *testing.T) {
provider, err := newMockedAzureProvider(NewDomainFilter([]string{"example.com"}), NewZoneIDFilter([]string{""}), true, "k8s",
provider, err := newMockedAzureProvider(NewDomainFilter([]string{"example.com"}), NewZoneIDFilter([]string{""}), true, "k8s", "",
&[]dns.Zone{
createMockZone("example.com", "/dnszones/example.com"),
},
......@@ -392,6 +393,7 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsC
NewZoneIDFilter([]string{""}),
dryRun,
"group",
"",
&zonesClient,
client,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment