Add e2e tests

This commit is contained in:
Manuel de Brito Fontes 2017-10-17 19:50:27 -03:00
parent 99a355f25d
commit 601fb7dacf
1163 changed files with 289217 additions and 14195 deletions

View file

@ -0,0 +1,253 @@
# Azure Active Directory library for Go
This project provides a stand alone Azure Active Directory library for Go. The code was extracted
from [go-autorest](https://github.com/Azure/go-autorest/) project, which is used as a base for
[azure-sdk-for-go](https://github.com/Azure/azure-sdk-for-go).
## Installation
```
go get -u github.com/Azure/go-autorest/autorest/adal
```
## Usage
An Active Directory application is required in order to use this library. An application can be registered in the [Azure Portal](https://portal.azure.com/) follow these [guidelines](https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-integrating-applications) or using the [Azure CLI](https://github.com/Azure/azure-cli).
### Register an Azure AD Application with secret
1. Register a new application with a `secret` credential
```
az ad app create \
--display-name example-app \
--homepage https://example-app/home \
--identifier-uris https://example-app/app \
--password secret
```
2. Create a service principal using the `Application ID` from previous step
```
az ad sp create --id "Application ID"
```
* Replace `Application ID` with `appId` from step 1.
### Register an Azure AD Application with certificate
1. Create a private key
```
openssl genrsa -out "example-app.key" 2048
```
2. Create the certificate
```
openssl req -new -key "example-app.key" -subj "/CN=example-app" -out "example-app.csr"
openssl x509 -req -in "example-app.csr" -signkey "example-app.key" -out "example-app.crt" -days 10000
```
3. Create the PKCS12 version of the certificate containing also the private key
```
openssl pkcs12 -export -out "example-app.pfx" -inkey "example-app.key" -in "example-app.crt" -passout pass:
```
4. Register a new application with the certificate content form `example-app.crt`
```
certificateContents="$(tail -n+2 "example-app.crt" | head -n-1)"
az ad app create \
--display-name example-app \
--homepage https://example-app/home \
--identifier-uris https://example-app/app \
--key-usage Verify --end-date 2018-01-01 \
--key-value "${certificateContents}"
```
5. Create a service principal using the `Application ID` from previous step
```
az ad sp create --id "APPLICATION_ID"
```
* Replace `APPLICATION_ID` with `appId` from step 4.
### Grant the necessary permissions
Azure relies on a Role-Based Access Control (RBAC) model to manage the access to resources at a fine-grained
level. There is a set of [pre-defined roles](https://docs.microsoft.com/en-us/azure/active-directory/role-based-access-built-in-roles)
which can be assigned to a service principal of an Azure AD application depending of your needs.
```
az role assignment create --assigner "SERVICE_PRINCIPAL_ID" --role "ROLE_NAME"
```
* Replace the `SERVICE_PRINCIPAL_ID` with the `appId` from previous step.
* Replace the `ROLE_NAME` with a role name of your choice.
It is also possible to define custom role definitions.
```
az role definition create --role-definition role-definition.json
```
* Check [custom roles](https://docs.microsoft.com/en-us/azure/active-directory/role-based-access-control-custom-roles) for more details regarding the content of `role-definition.json` file.
### Acquire Access Token
The common configuration used by all flows:
```Go
const activeDirectoryEndpoint = "https://login.microsoftonline.com/"
tenantID := "TENANT_ID"
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID)
applicationID := "APPLICATION_ID"
callback := func(token adal.Token) error {
// This is called after the token is acquired
}
// The resource for which the token is acquired
resource := "https://management.core.windows.net/"
```
* Replace the `TENANT_ID` with your tenant ID.
* Replace the `APPLICATION_ID` with the value from previous section.
#### Client Credentials
```Go
applicationSecret := "APPLICATION_SECRET"
spt, err := adal.NewServicePrincipalToken(
oauthConfig,
appliationID,
applicationSecret,
resource,
callbacks...)
if err != nil {
return nil, err
}
// Acquire a new access token
err = spt.Refresh()
if (err == nil) {
token := spt.Token
}
```
* Replace the `APPLICATION_SECRET` with the `password` value from previous section.
#### Client Certificate
```Go
certificatePath := "./example-app.pfx"
certData, err := ioutil.ReadFile(certificatePath)
if err != nil {
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
}
// Get the certificate and private key from pfx file
certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
if err != nil {
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
}
spt, err := adal.NewServicePrincipalTokenFromCertificate(
oauthConfig,
applicationID,
certificate,
rsaPrivateKey,
resource,
callbacks...)
// Acquire a new access token
err = spt.Refresh()
if (err == nil) {
token := spt.Token
}
```
* Update the certificate path to point to the example-app.pfx file which was created in previous section.
#### Device Code
```Go
oauthClient := &http.Client{}
// Acquire the device code
deviceCode, err := adal.InitiateDeviceAuth(
oauthClient,
oauthConfig,
applicationID,
resource)
if err != nil {
return nil, fmt.Errorf("Failed to start device auth flow: %s", err)
}
// Display the authentication message
fmt.Println(*deviceCode.Message)
// Wait here until the user is authenticated
token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
if err != nil {
return nil, fmt.Errorf("Failed to finish device auth flow: %s", err)
}
spt, err := adal.NewServicePrincipalTokenFromManualToken(
oauthConfig,
applicationID,
resource,
*token,
callbacks...)
if (err == nil) {
token := spt.Token
}
```
### Command Line Tool
A command line tool is available in `cmd/adal.go` that can acquire a token for a given resource. It supports all flows mentioned above.
```
adal -h
Usage of ./adal:
-applicationId string
application id
-certificatePath string
path to pk12/PFC application certificate
-mode string
authentication mode (device, secret, cert, refresh) (default "device")
-resource string
resource for which the token is requested
-secret string
application secret
-tenantId string
tenant id
-tokenCachePath string
location of oath token cache (default "/home/cgc/.adal/accessToken.json")
```
Example acquire a token for `https://management.core.windows.net/` using device code flow:
```
adal -mode device \
-applicationId "APPLICATION_ID" \
-tenantId "TENANT_ID" \
-resource https://management.core.windows.net/
```

View file

@ -0,0 +1,65 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"net/url"
)
const (
activeDirectoryAPIVersion = "1.0"
)
// OAuthConfig represents the endpoints needed
// in OAuth operations
type OAuthConfig struct {
AuthorityEndpoint url.URL
AuthorizeEndpoint url.URL
TokenEndpoint url.URL
DeviceCodeEndpoint url.URL
}
// NewOAuthConfig returns an OAuthConfig with tenant specific urls
func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, error) {
const activeDirectoryEndpointTemplate = "%s/oauth2/%s?api-version=%s"
u, err := url.Parse(activeDirectoryEndpoint)
if err != nil {
return nil, err
}
authorityURL, err := u.Parse(tenantID)
if err != nil {
return nil, err
}
authorizeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "authorize", activeDirectoryAPIVersion))
if err != nil {
return nil, err
}
tokenURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "token", activeDirectoryAPIVersion))
if err != nil {
return nil, err
}
deviceCodeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "devicecode", activeDirectoryAPIVersion))
if err != nil {
return nil, err
}
return &OAuthConfig{
AuthorityEndpoint: *authorityURL,
AuthorizeEndpoint: *authorizeURL,
TokenEndpoint: *tokenURL,
DeviceCodeEndpoint: *deviceCodeURL,
}, nil
}

View file

@ -0,0 +1,44 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"testing"
)
func TestNewOAuthConfig(t *testing.T) {
const testActiveDirectoryEndpoint = "https://login.test.com"
const testTenantID = "tenant-id-test"
config, err := NewOAuthConfig(testActiveDirectoryEndpoint, testTenantID)
if err != nil {
t.Fatalf("autorest/adal: Unexpected error while creating oauth configuration for tenant: %v.", err)
}
expected := "https://login.test.com/tenant-id-test/oauth2/authorize?api-version=1.0"
if config.AuthorizeEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.AuthorizeEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/token?api-version=1.0"
if config.TokenEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.TokenEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/devicecode?api-version=1.0"
if config.DeviceCodeEndpoint.String() != expected {
t.Fatalf("autorest/adal Incorrect devicecode url for Tenant from Environment. expected(%s). actual(%v).", expected, config.DeviceCodeEndpoint)
}
}

View file

@ -0,0 +1,242 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
This file is largely based on rjw57/oauth2device's code, with the follow differences:
* scope -> resource, and only allow a single one
* receive "Message" in the DeviceCode struct and show it to users as the prompt
* azure-xplat-cli has the following behavior that this emulates:
- does not send client_secret during the token exchange
- sends resource again in the token exchange request
*/
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
)
const (
logPrefix = "autorest/adal/devicetoken:"
)
var (
// ErrDeviceGeneric represents an unknown error from the token endpoint when using device flow
ErrDeviceGeneric = fmt.Errorf("%s Error while retrieving OAuth token: Unknown Error", logPrefix)
// ErrDeviceAccessDenied represents an access denied error from the token endpoint when using device flow
ErrDeviceAccessDenied = fmt.Errorf("%s Error while retrieving OAuth token: Access Denied", logPrefix)
// ErrDeviceAuthorizationPending represents the server waiting on the user to complete the device flow
ErrDeviceAuthorizationPending = fmt.Errorf("%s Error while retrieving OAuth token: Authorization Pending", logPrefix)
// ErrDeviceCodeExpired represents the server timing out and expiring the code during device flow
ErrDeviceCodeExpired = fmt.Errorf("%s Error while retrieving OAuth token: Code Expired", logPrefix)
// ErrDeviceSlowDown represents the service telling us we're polling too often during device flow
ErrDeviceSlowDown = fmt.Errorf("%s Error while retrieving OAuth token: Slow Down", logPrefix)
// ErrDeviceCodeEmpty represents an empty device code from the device endpoint while using device flow
ErrDeviceCodeEmpty = fmt.Errorf("%s Error while retrieving device code: Device Code Empty", logPrefix)
// ErrOAuthTokenEmpty represents an empty OAuth token from the token endpoint when using device flow
ErrOAuthTokenEmpty = fmt.Errorf("%s Error while retrieving OAuth token: Token Empty", logPrefix)
errCodeSendingFails = "Error occurred while sending request for Device Authorization Code"
errCodeHandlingFails = "Error occurred while handling response from the Device Endpoint"
errTokenSendingFails = "Error occurred while sending request with device code for a token"
errTokenHandlingFails = "Error occurred while handling response from the Token Endpoint (during device flow)"
errStatusNotOK = "Error HTTP status != 200"
)
// DeviceCode is the object returned by the device auth endpoint
// It contains information to instruct the user to complete the auth flow
type DeviceCode struct {
DeviceCode *string `json:"device_code,omitempty"`
UserCode *string `json:"user_code,omitempty"`
VerificationURL *string `json:"verification_url,omitempty"`
ExpiresIn *int64 `json:"expires_in,string,omitempty"`
Interval *int64 `json:"interval,string,omitempty"`
Message *string `json:"message"` // Azure specific
Resource string // store the following, stored when initiating, used when exchanging
OAuthConfig OAuthConfig
ClientID string
}
// TokenError is the object returned by the token exchange endpoint
// when something is amiss
type TokenError struct {
Error *string `json:"error,omitempty"`
ErrorCodes []int `json:"error_codes,omitempty"`
ErrorDescription *string `json:"error_description,omitempty"`
Timestamp *string `json:"timestamp,omitempty"`
TraceID *string `json:"trace_id,omitempty"`
}
// DeviceToken is the object return by the token exchange endpoint
// It can either look like a Token or an ErrorToken, so put both here
// and check for presence of "Error" to know if we are in error state
type deviceToken struct {
Token
TokenError
}
// InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
v := url.Values{
"client_id": []string{clientID},
"resource": []string{resource},
}
s := v.Encode()
body := ioutil.NopCloser(strings.NewReader(s))
req, err := http.NewRequest(http.MethodPost, oauthConfig.DeviceCodeEndpoint.String(), body)
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
}
req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
resp, err := sender.Do(req)
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
}
defer resp.Body.Close()
rb, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, errStatusNotOK)
}
if len(strings.Trim(string(rb), " ")) == 0 {
return nil, ErrDeviceCodeEmpty
}
var code DeviceCode
err = json.Unmarshal(rb, &code)
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
}
code.ClientID = clientID
code.Resource = resource
code.OAuthConfig = oauthConfig
return &code, nil
}
// CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
// to see if the device flow has: been completed, timed out, or otherwise failed
func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
v := url.Values{
"client_id": []string{code.ClientID},
"code": []string{*code.DeviceCode},
"grant_type": []string{OAuthGrantTypeDeviceCode},
"resource": []string{code.Resource},
}
s := v.Encode()
body := ioutil.NopCloser(strings.NewReader(s))
req, err := http.NewRequest(http.MethodPost, code.OAuthConfig.TokenEndpoint.String(), body)
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
}
req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
resp, err := sender.Do(req)
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
}
defer resp.Body.Close()
rb, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
}
if resp.StatusCode != http.StatusOK && len(strings.Trim(string(rb), " ")) == 0 {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, errStatusNotOK)
}
if len(strings.Trim(string(rb), " ")) == 0 {
return nil, ErrOAuthTokenEmpty
}
var token deviceToken
err = json.Unmarshal(rb, &token)
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
}
if token.Error == nil {
return &token.Token, nil
}
switch *token.Error {
case "authorization_pending":
return nil, ErrDeviceAuthorizationPending
case "slow_down":
return nil, ErrDeviceSlowDown
case "access_denied":
return nil, ErrDeviceAccessDenied
case "code_expired":
return nil, ErrDeviceCodeExpired
default:
return nil, ErrDeviceGeneric
}
}
// WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
// This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
intervalDuration := time.Duration(*code.Interval) * time.Second
waitDuration := intervalDuration
for {
token, err := CheckForUserCompletion(sender, code)
if err == nil {
return token, nil
}
switch err {
case ErrDeviceSlowDown:
waitDuration += waitDuration
case ErrDeviceAuthorizationPending:
// noop
default: // everything else is "fatal" to us
return nil, err
}
if waitDuration > (intervalDuration * 3) {
return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
}
time.Sleep(waitDuration)
}
}

View file

@ -0,0 +1,330 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
TestResource = "SomeResource"
TestClientID = "SomeClientID"
TestTenantID = "SomeTenantID"
TestActiveDirectoryEndpoint = "https://login.test.com/"
)
var (
testOAuthConfig, _ = NewOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID)
TestOAuthConfig = *testOAuthConfig
)
const MockDeviceCodeResponse = `
{
"device_code": "10000-40-1234567890",
"user_code": "ABCDEF",
"verification_url": "http://aka.ms/deviceauth",
"expires_in": "900",
"interval": "0"
}
`
const MockDeviceTokenResponse = `{
"access_token": "accessToken",
"refresh_token": "refreshToken",
"expires_in": "1000",
"expires_on": "2000",
"not_before": "3000",
"resource": "resource",
"token_type": "type"
}
`
func TestDeviceCodeIncludesResource(t *testing.T) {
sender := mocks.NewSender()
sender.AppendResponse(mocks.NewResponseWithContent(MockDeviceCodeResponse))
code, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err != nil {
t.Fatalf("adal: unexpected error initiating device auth")
}
if code.Resource != TestResource {
t.Fatalf("adal: InitiateDeviceAuth failed to stash the resource in the DeviceCode struct")
}
}
func TestDeviceCodeReturnsErrorIfSendingFails(t *testing.T) {
sender := mocks.NewSender()
sender.SetError(fmt.Errorf("this is an error"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err == nil || !strings.Contains(err.Error(), errCodeSendingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errCodeSendingFails, err.Error())
}
}
func TestDeviceCodeReturnsErrorIfBadRequest(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("doesn't matter")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err == nil || !strings.Contains(err.Error(), errCodeHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errCodeHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceCodeReturnsErrorIfCannotDeserializeDeviceCode(t *testing.T) {
gibberishJSON := strings.Replace(MockDeviceCodeResponse, "expires_in", "\":, :gibberish", -1)
sender := mocks.NewSender()
body := mocks.NewBody(gibberishJSON)
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err == nil || !strings.Contains(err.Error(), errCodeHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errCodeHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceCodeReturnsErrorIfEmptyDeviceCode(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err != ErrDeviceCodeEmpty {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", ErrDeviceCodeEmpty, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func deviceCode() *DeviceCode {
var deviceCode DeviceCode
_ = json.Unmarshal([]byte(MockDeviceCodeResponse), &deviceCode)
deviceCode.Resource = TestResource
deviceCode.ClientID = TestClientID
return &deviceCode
}
func TestDeviceTokenReturns(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(MockDeviceTokenResponse)
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != nil {
t.Fatalf("adal: got error unexpectedly")
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfSendingFails(t *testing.T) {
sender := mocks.NewSender()
sender.SetError(fmt.Errorf("this is an error"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil || !strings.Contains(err.Error(), errTokenSendingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errTokenSendingFails, err.Error())
}
}
func TestDeviceTokenReturnsErrorIfServerError(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusInternalServerError, "Internal Server Error"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil || !strings.Contains(err.Error(), errTokenHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errTokenHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfCannotDeserializeDeviceToken(t *testing.T) {
gibberishJSON := strings.Replace(MockDeviceTokenResponse, "expires_in", ";:\"gibberish", -1)
sender := mocks.NewSender()
body := mocks.NewBody(gibberishJSON)
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil || !strings.Contains(err.Error(), errTokenHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errTokenHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func errorDeviceTokenResponse(message string) string {
return `{ "error": "` + message + `" }`
}
func TestDeviceTokenReturnsErrorIfAuthorizationPending(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("authorization_pending"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := CheckForUserCompletion(sender, deviceCode())
if err != ErrDeviceAuthorizationPending {
t.Fatalf("!!!")
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfSlowDown(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("slow_down"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := CheckForUserCompletion(sender, deviceCode())
if err != ErrDeviceSlowDown {
t.Fatalf("!!!")
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
type deviceTokenSender struct {
errorString string
attempts int
}
func newDeviceTokenSender(deviceErrorString string) *deviceTokenSender {
return &deviceTokenSender{errorString: deviceErrorString, attempts: 0}
}
func (s *deviceTokenSender) Do(req *http.Request) (*http.Response, error) {
var resp *http.Response
if s.attempts < 1 {
s.attempts++
resp = mocks.NewResponseWithContent(errorDeviceTokenResponse(s.errorString))
} else {
resp = mocks.NewResponseWithContent(MockDeviceTokenResponse)
}
return resp, nil
}
// since the above only exercise CheckForUserCompletion, we repeat the test here,
// but with the intent of showing that WaitForUserCompletion loops properly.
func TestDeviceTokenSucceedsWithIntermediateAuthPending(t *testing.T) {
sender := newDeviceTokenSender("authorization_pending")
_, err := WaitForUserCompletion(sender, deviceCode())
if err != nil {
t.Fatalf("unexpected error occurred")
}
}
// same as above but with SlowDown now
func TestDeviceTokenSucceedsWithIntermediateSlowDown(t *testing.T) {
sender := newDeviceTokenSender("slow_down")
_, err := WaitForUserCompletion(sender, deviceCode())
if err != nil {
t.Fatalf("unexpected error occurred")
}
}
func TestDeviceTokenReturnsErrorIfAccessDenied(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("access_denied"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != ErrDeviceAccessDenied {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrDeviceAccessDenied.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfCodeExpired(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("code_expired"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != ErrDeviceCodeExpired {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrDeviceCodeExpired.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorForUnknownError(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("unknown_error"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil {
t.Fatalf("failed to get error")
}
if err != ErrDeviceGeneric {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrDeviceGeneric.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfTokenEmptyAndStatusOK(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != ErrOAuthTokenEmpty {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrOAuthTokenEmpty.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}

View file

@ -0,0 +1,20 @@
// +build !windows
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// msiPath is the path to the MSI Extension settings file (to discover the endpoint)
var msiPath = "/var/lib/waagent/ManagedIdentity-Settings"

View file

@ -0,0 +1,25 @@
// +build windows
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"os"
"strings"
)
// msiPath is the path to the MSI Extension settings file (to discover the endpoint)
var msiPath = strings.Join([]string{os.Getenv("SystemDrive"), "WindowsAzure/Config/ManagedIdentity-Settings"}, "/")

View file

@ -0,0 +1,73 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
)
// LoadToken restores a Token object from a file located at 'path'.
func LoadToken(path string) (*Token, error) {
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open file (%s) while loading token: %v", path, err)
}
defer file.Close()
var token Token
dec := json.NewDecoder(file)
if err = dec.Decode(&token); err != nil {
return nil, fmt.Errorf("failed to decode contents of file (%s) into Token representation: %v", path, err)
}
return &token, nil
}
// SaveToken persists an oauth token at the given location on disk.
// It moves the new file into place so it can safely be used to replace an existing file
// that maybe accessed by multiple processes.
func SaveToken(path string, mode os.FileMode, token Token) error {
dir := filepath.Dir(path)
err := os.MkdirAll(dir, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create directory (%s) to store token in: %v", dir, err)
}
newFile, err := ioutil.TempFile(dir, "token")
if err != nil {
return fmt.Errorf("failed to create the temp file to write the token: %v", err)
}
tempPath := newFile.Name()
if err := json.NewEncoder(newFile).Encode(token); err != nil {
return fmt.Errorf("failed to encode token to file (%s) while saving token: %v", tempPath, err)
}
if err := newFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file %s: %v", tempPath, err)
}
// Atomic replace to avoid multi-writer file corruptions
if err := os.Rename(tempPath, path); err != nil {
return fmt.Errorf("failed to move temporary token to desired output location. src=%s dst=%s: %v", tempPath, path, err)
}
if err := os.Chmod(path, mode); err != nil {
return fmt.Errorf("failed to chmod the token file %s: %v", path, err)
}
return nil
}

View file

@ -0,0 +1,171 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"io/ioutil"
"os"
"path"
"reflect"
"runtime"
"strings"
"testing"
)
const MockTokenJSON string = `{
"access_token": "accessToken",
"refresh_token": "refreshToken",
"expires_in": "1000",
"expires_on": "2000",
"not_before": "3000",
"resource": "resource",
"token_type": "type"
}`
var TestToken = Token{
AccessToken: "accessToken",
RefreshToken: "refreshToken",
ExpiresIn: "1000",
ExpiresOn: "2000",
NotBefore: "3000",
Resource: "resource",
Type: "type",
}
func writeTestTokenFile(t *testing.T, suffix string, contents string) *os.File {
f, err := ioutil.TempFile(os.TempDir(), suffix)
if err != nil {
t.Fatalf("azure: unexpected error when creating temp file: %v", err)
}
defer f.Close()
_, err = f.Write([]byte(contents))
if err != nil {
t.Fatalf("azure: unexpected error when writing temp test file: %v", err)
}
return f
}
func TestLoadToken(t *testing.T) {
f := writeTestTokenFile(t, "testloadtoken", MockTokenJSON)
defer os.Remove(f.Name())
expectedToken := TestToken
actualToken, err := LoadToken(f.Name())
if err != nil {
t.Fatalf("azure: unexpected error loading token from file: %v", err)
}
if *actualToken != expectedToken {
t.Fatalf("azure: failed to decode properly expected(%v) actual(%v)", expectedToken, *actualToken)
}
// test that LoadToken closes the file properly
err = SaveToken(f.Name(), 0600, *actualToken)
if err != nil {
t.Fatalf("azure: could not save token after LoadToken: %v", err)
}
}
func TestLoadTokenFailsBadPath(t *testing.T) {
_, err := LoadToken("/tmp/this_file_should_never_exist_really")
expectedSubstring := "failed to open file"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%s)", expectedSubstring, err.Error())
}
}
func TestLoadTokenFailsBadJson(t *testing.T) {
gibberishJSON := strings.Replace(MockTokenJSON, "expires_on", ";:\"gibberish", -1)
f := writeTestTokenFile(t, "testloadtokenfailsbadjson", gibberishJSON)
defer os.Remove(f.Name())
_, err := LoadToken(f.Name())
expectedSubstring := "failed to decode contents of file"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%s)", expectedSubstring, err.Error())
}
}
func token() *Token {
var token Token
json.Unmarshal([]byte(MockTokenJSON), &token)
return &token
}
func TestSaveToken(t *testing.T) {
f, err := ioutil.TempFile("", "testloadtoken")
if err != nil {
t.Fatalf("azure: unexpected error when creating temp file: %v", err)
}
defer os.Remove(f.Name())
f.Close()
mode := os.ModePerm & 0642
err = SaveToken(f.Name(), mode, *token())
if err != nil {
t.Fatalf("azure: unexpected error saving token to file: %v", err)
}
fi, err := os.Stat(f.Name()) // open a new stat as held ones are not fresh
if err != nil {
t.Fatalf("azure: stat failed: %v", err)
}
if runtime.GOOS != "windows" { // permissions don't work on Windows
if perm := fi.Mode().Perm(); perm != mode {
t.Fatalf("azure: wrong file perm. got:%s; expected:%s file :%s", perm, mode, f.Name())
}
}
var actualToken Token
var expectedToken Token
json.Unmarshal([]byte(MockTokenJSON), expectedToken)
contents, err := ioutil.ReadFile(f.Name())
if err != nil {
t.Fatal("!!")
}
json.Unmarshal(contents, actualToken)
if !reflect.DeepEqual(actualToken, expectedToken) {
t.Fatal("azure: token was not serialized correctly")
}
}
func TestSaveTokenFailsNoPermission(t *testing.T) {
pathWhereWeShouldntHavePermission := "/usr/thiswontwork/atall"
if runtime.GOOS == "windows" {
pathWhereWeShouldntHavePermission = path.Join(os.Getenv("windir"), "system32\\mytokendir\\mytoken")
}
err := SaveToken(pathWhereWeShouldntHavePermission, 0644, *token())
expectedSubstring := "failed to create directory"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%v)", expectedSubstring, err)
}
}
func TestSaveTokenFailsCantCreate(t *testing.T) {
tokenPath := "/thiswontwork"
if runtime.GOOS == "windows" {
tokenPath = path.Join(os.Getenv("windir"), "system32")
}
err := SaveToken(tokenPath, 0644, *token())
expectedSubstring := "failed to create the temp file to write the token"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%v)", expectedSubstring, err)
}
}

View file

@ -0,0 +1,60 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"net/http"
)
const (
contentType = "Content-Type"
mimeTypeFormPost = "application/x-www-form-urlencoded"
)
// Sender is the interface that wraps the Do method to send HTTP requests.
//
// The standard http.Client conforms to this interface.
type Sender interface {
Do(*http.Request) (*http.Response, error)
}
// SenderFunc is a method that implements the Sender interface.
type SenderFunc func(*http.Request) (*http.Response, error)
// Do implements the Sender interface on SenderFunc.
func (sf SenderFunc) Do(r *http.Request) (*http.Response, error) {
return sf(r)
}
// SendDecorator takes and possibily decorates, by wrapping, a Sender. Decorators may affect the
// http.Request and pass it along or, first, pass the http.Request along then react to the
// http.Response result.
type SendDecorator func(Sender) Sender
// CreateSender creates, decorates, and returns, as a Sender, the default http.Client.
func CreateSender(decorators ...SendDecorator) Sender {
return DecorateSender(&http.Client{}, decorators...)
}
// DecorateSender accepts a Sender and a, possibly empty, set of SendDecorators, which is applies to
// the Sender. Decorators are applied in the order received, but their affect upon the request
// depends on whether they are a pre-decorator (change the http.Request and then pass it along) or a
// post-decorator (pass the http.Request along and react to the results in http.Response).
func DecorateSender(s Sender, decorators ...SendDecorator) Sender {
for _, decorate := range decorators {
s = decorate(s)
}
return s
}

View file

@ -0,0 +1,427 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Azure/go-autorest/autorest/date"
"github.com/dgrijalva/jwt-go"
)
const (
defaultRefresh = 5 * time.Minute
// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
OAuthGrantTypeDeviceCode = "device_code"
// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
OAuthGrantTypeClientCredentials = "client_credentials"
// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
OAuthGrantTypeRefreshToken = "refresh_token"
// metadataHeader is the header required by MSI extension
metadataHeader = "Metadata"
)
// OAuthTokenProvider is an interface which should be implemented by an access token retriever
type OAuthTokenProvider interface {
OAuthToken() string
}
// Refresher is an interface for token refresh functionality
type Refresher interface {
Refresh() error
RefreshExchange(resource string) error
EnsureFresh() error
}
// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh
type TokenRefreshCallback func(Token) error
// Token encapsulates the access token used to authorize Azure requests.
type Token struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn string `json:"expires_in"`
ExpiresOn string `json:"expires_on"`
NotBefore string `json:"not_before"`
Resource string `json:"resource"`
Type string `json:"token_type"`
}
// Expires returns the time.Time when the Token expires.
func (t Token) Expires() time.Time {
s, err := strconv.Atoi(t.ExpiresOn)
if err != nil {
s = -3600
}
expiration := date.NewUnixTimeFromSeconds(float64(s))
return time.Time(expiration).UTC()
}
// IsExpired returns true if the Token is expired, false otherwise.
func (t Token) IsExpired() bool {
return t.WillExpireIn(0)
}
// WillExpireIn returns true if the Token will expire after the passed time.Duration interval
// from now, false otherwise.
func (t Token) WillExpireIn(d time.Duration) bool {
return !t.Expires().After(time.Now().Add(d))
}
//OAuthToken return the current access token
func (t *Token) OAuthToken() string {
return t.AccessToken
}
// ServicePrincipalNoSecret represents a secret type that contains no secret
// meaning it is not valid for fetching a fresh token. This is used by Manual
type ServicePrincipalNoSecret struct {
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret
// It only returns an error for the ServicePrincipalNoSecret type
func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
}
// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
// that is submitted when acquiring an oAuth token.
type ServicePrincipalSecret interface {
SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
}
// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
type ServicePrincipalTokenSecret struct {
ClientSecret string
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
// It will populate the form submitted during oAuth Token Acquisition using the client_secret.
func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("client_secret", tokenSecret.ClientSecret)
return nil
}
// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
type ServicePrincipalCertificateSecret struct {
Certificate *x509.Certificate
PrivateKey *rsa.PrivateKey
}
// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
type ServicePrincipalMSISecret struct {
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil
}
// SignJwt returns the JWT signed with the certificate's private key.
func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
hasher := sha1.New()
_, err := hasher.Write(secret.Certificate.Raw)
if err != nil {
return "", err
}
thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
// The jti (JWT ID) claim provides a unique identifier for the JWT.
jti := make([]byte, 20)
_, err = rand.Read(jti)
if err != nil {
return "", err
}
token := jwt.New(jwt.SigningMethodRS256)
token.Header["x5t"] = thumbprint
token.Claims = jwt.MapClaims{
"aud": spt.oauthConfig.TokenEndpoint.String(),
"iss": spt.clientID,
"sub": spt.clientID,
"jti": base64.URLEncoding.EncodeToString(jti),
"nbf": time.Now().Unix(),
"exp": time.Now().Add(time.Hour * 24).Unix(),
}
signedString, err := token.SignedString(secret.PrivateKey)
return signedString, err
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
jwt, err := secret.SignJwt(spt)
if err != nil {
return err
}
v.Set("client_assertion", jwt)
v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
return nil
}
// ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct {
Token
secret ServicePrincipalSecret
oauthConfig OAuthConfig
clientID string
resource string
autoRefresh bool
refreshWithin time.Duration
sender Sender
refreshCallbacks []TokenRefreshCallback
}
// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
spt := &ServicePrincipalToken{
oauthConfig: oauthConfig,
secret: secret,
clientID: id,
resource: resource,
autoRefresh: true,
refreshWithin: defaultRefresh,
sender: &http.Client{},
refreshCallbacks: callbacks,
}
return spt, nil
}
// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
spt, err := NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalNoSecret{},
callbacks...)
if err != nil {
return nil, err
}
spt.Token = token
return spt, nil
}
// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
// credentials scoped to the named resource.
func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalTokenSecret{
ClientSecret: secret,
},
callbacks...,
)
}
// NewServicePrincipalTokenFromCertificate create a ServicePrincipalToken from the supplied pkcs12 bytes.
func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalCertificateSecret{
PrivateKey: privateKey,
Certificate: certificate,
},
callbacks...,
)
}
// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
func GetMSIVMEndpoint() (string, error) {
return getMSIVMEndpoint(msiPath)
}
func getMSIVMEndpoint(path string) (string, error) {
// Read MSI settings
bytes, err := ioutil.ReadFile(path)
if err != nil {
return "", err
}
msiSettings := struct {
URL string `json:"url"`
}{}
err = json.Unmarshal(bytes, &msiSettings)
if err != nil {
return "", err
}
return msiSettings.URL, nil
}
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
// We set the oauth config token endpoint to be MSI's endpoint
msiEndpointURL, err := url.Parse(msiEndpoint)
if err != nil {
return nil, err
}
oauthConfig, err := NewOAuthConfig(msiEndpointURL.String(), "")
if err != nil {
return nil, err
}
spt := &ServicePrincipalToken{
oauthConfig: *oauthConfig,
secret: &ServicePrincipalMSISecret{},
resource: resource,
autoRefresh: true,
refreshWithin: defaultRefresh,
sender: &http.Client{},
refreshCallbacks: callbacks,
}
return spt, nil
}
// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on.
func (spt *ServicePrincipalToken) EnsureFresh() error {
if spt.autoRefresh && spt.WillExpireIn(spt.refreshWithin) {
return spt.Refresh()
}
return nil
}
// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
if spt.refreshCallbacks != nil {
for _, callback := range spt.refreshCallbacks {
err := callback(spt.Token)
if err != nil {
return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
}
}
}
return nil
}
// Refresh obtains a fresh token for the Service Principal.
func (spt *ServicePrincipalToken) Refresh() error {
return spt.refreshInternal(spt.resource)
}
// RefreshExchange refreshes the token, but for a different resource.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
return spt.refreshInternal(resource)
}
func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
v := url.Values{}
v.Set("client_id", spt.clientID)
v.Set("resource", resource)
if spt.RefreshToken != "" {
v.Set("grant_type", OAuthGrantTypeRefreshToken)
v.Set("refresh_token", spt.RefreshToken)
} else {
v.Set("grant_type", OAuthGrantTypeClientCredentials)
err := spt.secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
}
s := v.Encode()
body := ioutil.NopCloser(strings.NewReader(s))
req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), body)
if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
}
req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok {
req.Header.Set(metadataHeader, "true")
}
resp, err := spt.sender.Do(req)
if err != nil {
return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
}
defer resp.Body.Close()
rb, err := ioutil.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
if err != nil {
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode)
}
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb))
}
if err != nil {
return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
}
if len(strings.Trim(string(rb), " ")) == 0 {
return fmt.Errorf("adal: Empty service principal token received during refresh")
}
var token Token
err = json.Unmarshal(rb, &token)
if err != nil {
return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
}
spt.Token = token
return spt.InvokeRefreshCallbacks(token)
}
// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
spt.autoRefresh = autoRefresh
}
// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
// refresh the token.
func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
spt.refreshWithin = d
return
}
// SetSender sets the http.Client used when obtaining the Service Principal token. An
// undecorated http.Client is used by default.
func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }

View file

@ -0,0 +1,654 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
defaultFormData = "client_id=id&client_secret=secret&grant_type=client_credentials&resource=resource"
defaultManualFormData = "client_id=id&grant_type=refresh_token&refresh_token=refreshtoken&resource=resource"
)
func TestTokenExpires(t *testing.T) {
tt := time.Now().Add(5 * time.Second)
tk := newTokenExpiresAt(tt)
if tk.Expires().Equal(tt) {
t.Fatalf("adal: Token#Expires miscalculated expiration time -- received %v, expected %v", tk.Expires(), tt)
}
}
func TestTokenIsExpired(t *testing.T) {
tk := newTokenExpiresAt(time.Now().Add(-5 * time.Second))
if !tk.IsExpired() {
t.Fatalf("adal: Token#IsExpired failed to mark a stale token as expired -- now %v, token expires at %v",
time.Now().UTC(), tk.Expires())
}
}
func TestTokenIsExpiredUninitialized(t *testing.T) {
tk := &Token{}
if !tk.IsExpired() {
t.Fatalf("adal: An uninitialized Token failed to mark itself as expired (expiration time %v)", tk.Expires())
}
}
func TestTokenIsNoExpired(t *testing.T) {
tk := newTokenExpiresAt(time.Now().Add(1000 * time.Second))
if tk.IsExpired() {
t.Fatalf("adal: Token marked a fresh token as expired -- now %v, token expires at %v", time.Now().UTC(), tk.Expires())
}
}
func TestTokenWillExpireIn(t *testing.T) {
d := 5 * time.Second
tk := newTokenExpiresIn(d)
if !tk.WillExpireIn(d) {
t.Fatal("adal: Token#WillExpireIn mismeasured expiration time")
}
}
func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
spt := newServicePrincipalToken()
if !spt.autoRefresh {
t.Fatal("adal: ServicePrincipalToken did not default to automatic token refreshing")
}
spt.SetAutoRefresh(false)
if spt.autoRefresh {
t.Fatal("adal: ServicePrincipalToken#SetAutoRefresh did not disable automatic token refreshing")
}
}
func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
spt := newServicePrincipalToken()
if spt.refreshWithin != defaultRefresh {
t.Fatal("adal: ServicePrincipalToken did not correctly set the default refresh interval")
}
spt.SetRefreshWithin(2 * defaultRefresh)
if spt.refreshWithin != 2*defaultRefresh {
t.Fatal("adal: ServicePrincipalToken#SetRefreshWithin did not set the refresh interval")
}
}
func TestServicePrincipalTokenSetSender(t *testing.T) {
spt := newServicePrincipalToken()
c := &http.Client{}
spt.SetSender(c)
if !reflect.DeepEqual(c, spt.sender) {
t.Fatal("adal: ServicePrincipalToken#SetSender did not set the sender")
}
}
func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != "POST" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
}
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if body.IsOpen() {
t.Fatalf("the response was not closed!")
}
}
func TestServicePrincipalTokenFromMSIRefreshUsesPOST(t *testing.T) {
resource := "https://resource"
cb := func(token Token) error { return nil }
spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != "POST" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
}
if h := r.Header.Get("Metadata"); h != "true" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
}
return resp, nil
})
}
})())
spt.SetSender(s)
err = spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if body.IsOpen() {
t.Fatalf("the response was not closed!")
}
}
func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Header.Get(http.CanonicalHeaderKey("Content-Type")) != "application/x-www-form-urlencoded" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Content-Type -- expected %v, received %v",
"application/x-form-urlencoded",
r.Header.Get(http.CanonicalHeaderKey("Content-Type")))
}
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
}
func TestServicePrincipalTokenRefreshSetsURL(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.String() != TestOAuthConfig.TokenEndpoint.String() {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the URL -- expected %v, received %v",
TestOAuthConfig.TokenEndpoint, r.URL)
}
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
}
func testServicePrincipalTokenRefreshSetsBody(t *testing.T, spt *ServicePrincipalToken, f func(*testing.T, []byte)) {
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("adal: Failed to read body of Service Principal token request (%v)", err)
}
f(t, b)
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
}
func TestServicePrincipalTokenManualRefreshSetsBody(t *testing.T) {
sptManual := newServicePrincipalTokenManual()
testServicePrincipalTokenRefreshSetsBody(t, sptManual, func(t *testing.T, b []byte) {
if string(b) != defaultManualFormData {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
defaultManualFormData, string(b))
}
})
}
func TestServicePrincipalTokenCertficateRefreshSetsBody(t *testing.T) {
sptCert := newServicePrincipalTokenCertificate(t)
testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) {
body := string(b)
values, _ := url.ParseQuery(body)
if values["client_assertion_type"][0] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ||
values["client_id"][0] != "id" ||
values["grant_type"][0] != "client_credentials" ||
values["resource"][0] != "resource" {
t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
}
})
}
func TestServicePrincipalTokenSecretRefreshSetsBody(t *testing.T) {
spt := newServicePrincipalToken()
testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
if string(b) != defaultFormData {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
defaultFormData, string(b))
}
})
}
func TestServicePrincipalTokenRefreshClosesRequestBody(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if resp.Body.(*mocks.Body).IsOpen() {
t.Fatal("adal: ServicePrincipalToken#Refresh failed to close the HTTP Response Body")
}
}
func TestServicePrincipalTokenRefreshRejectsResponsesWithStatusNotOK(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusUnauthorized, "Unauthorized")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh should reject a response with status != %d", http.StatusOK)
}
}
func TestServicePrincipalTokenRefreshRejectsEmptyBody(t *testing.T) {
spt := newServicePrincipalToken()
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return mocks.NewResponse(), nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err == nil {
t.Fatal("adal: ServicePrincipalToken#Refresh should reject an empty token")
}
}
func TestServicePrincipalTokenRefreshPropagatesErrors(t *testing.T) {
spt := newServicePrincipalToken()
c := mocks.NewSender()
c.SetError(fmt.Errorf("Faux Error"))
spt.SetSender(c)
err := spt.Refresh()
if err == nil {
t.Fatal("adal: Failed to propagate the request error")
}
}
func TestServicePrincipalTokenRefreshReturnsErrorIfNotOk(t *testing.T) {
spt := newServicePrincipalToken()
c := mocks.NewSender()
c.AppendResponse(mocks.NewResponseWithStatus("401 NotAuthorized", http.StatusUnauthorized))
spt.SetSender(c)
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: Failed to return an when receiving a status code other than HTTP %d", http.StatusOK)
}
}
func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
spt := newServicePrincipalToken()
expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
j := newTokenJSON(expiresOn, "resource")
resp := mocks.NewResponseWithContent(j)
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
} else if spt.AccessToken != "accessToken" ||
spt.ExpiresIn != "3600" ||
spt.ExpiresOn != expiresOn ||
spt.NotBefore != expiresOn ||
spt.Resource != "resource" ||
spt.Type != "Bearer" {
t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v",
j, *spt)
}
}
func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
spt := newServicePrincipalToken()
expireToken(&spt.Token)
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
f := false
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.EnsureFresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
}
if !f {
t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
}
}
func TestServicePrincipalTokenEnsureFreshSkipsIfFresh(t *testing.T) {
spt := newServicePrincipalToken()
setTokenToExpireIn(&spt.Token, 1000*time.Second)
f := false
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return mocks.NewResponse(), nil
})
}
})())
spt.SetSender(s)
err := spt.EnsureFresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
}
if f {
t.Fatal("adal: ServicePrincipalToken#EnsureFresh invoked Refresh for fresh token")
}
}
func TestRefreshCallback(t *testing.T) {
callbackTriggered := false
spt := newServicePrincipalToken(func(Token) error {
callbackTriggered = true
return nil
})
expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
sender := mocks.NewSender()
j := newTokenJSON(expiresOn, "resource")
sender.AppendResponse(mocks.NewResponseWithContent(j))
spt.SetSender(sender)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if !callbackTriggered {
t.Fatalf("adal: RefreshCallback failed to trigger call callback")
}
}
func TestRefreshCallbackErrorPropagates(t *testing.T) {
errorText := "this is an error text"
spt := newServicePrincipalToken(func(Token) error {
return fmt.Errorf(errorText)
})
expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
sender := mocks.NewSender()
j := newTokenJSON(expiresOn, "resource")
sender.AppendResponse(mocks.NewResponseWithContent(j))
spt.SetSender(sender)
err := spt.Refresh()
if err == nil || !strings.Contains(err.Error(), errorText) {
t.Fatalf("adal: RefreshCallback failed to propagate error")
}
}
// This demonstrates the danger of manual token without a refresh token
func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
spt := newServicePrincipalTokenManual()
spt.RefreshToken = ""
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh should have failed with a ManualTokenSecret without a refresh token")
}
}
func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
resource := "https://resource"
cb := func(token Token) error { return nil }
spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}
// check some of the SPT fields
if _, ok := spt.secret.(*ServicePrincipalMSISecret); !ok {
t.Fatal("SPT secret was not of MSI type")
}
if spt.resource != resource {
t.Fatal("SPT came back with incorrect resource")
}
if len(spt.refreshCallbacks) != 1 {
t.Fatal("SPT had incorrect refresh callbacks.")
}
}
func TestGetVMEndpoint(t *testing.T) {
tempSettingsFile, err := ioutil.TempFile("", "ManagedIdentity-Settings")
if err != nil {
t.Fatal("Couldn't write temp settings file")
}
defer os.Remove(tempSettingsFile.Name())
settingsContents := []byte(`{
"url": "http://msiendpoint/"
}`)
if _, err := tempSettingsFile.Write(settingsContents); err != nil {
t.Fatal("Couldn't fill temp settings file")
}
endpoint, err := getMSIVMEndpoint(tempSettingsFile.Name())
if err != nil {
t.Fatal("Coudn't get VM endpoint")
}
if endpoint != "http://msiendpoint/" {
t.Fatal("Didn't get correct endpoint")
}
}
func newToken() *Token {
return &Token{
AccessToken: "ASECRETVALUE",
Resource: "https://azure.microsoft.com/",
Type: "Bearer",
}
}
func newTokenJSON(expiresOn string, resource string) string {
return fmt.Sprintf(`{
"access_token" : "accessToken",
"expires_in" : "3600",
"expires_on" : "%s",
"not_before" : "%s",
"resource" : "%s",
"token_type" : "Bearer"
}`,
expiresOn, expiresOn, resource)
}
func newTokenExpiresIn(expireIn time.Duration) *Token {
return setTokenToExpireIn(newToken(), expireIn)
}
func newTokenExpiresAt(expireAt time.Time) *Token {
return setTokenToExpireAt(newToken(), expireAt)
}
func expireToken(t *Token) *Token {
return setTokenToExpireIn(t, 0)
}
func setTokenToExpireAt(t *Token, expireAt time.Time) *Token {
t.ExpiresIn = "3600"
t.ExpiresOn = strconv.Itoa(int(expireAt.Sub(date.UnixEpoch()).Seconds()))
t.NotBefore = t.ExpiresOn
return t
}
func setTokenToExpireIn(t *Token, expireIn time.Duration) *Token {
return setTokenToExpireAt(t, time.Now().Add(expireIn))
}
func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincipalToken {
spt, _ := NewServicePrincipalToken(TestOAuthConfig, "id", "secret", "resource", callbacks...)
return spt
}
func newServicePrincipalTokenManual() *ServicePrincipalToken {
token := newToken()
token.RefreshToken = "refreshtoken"
spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", *token)
return spt
}
func newServicePrincipalTokenCertificate(t *testing.T) *ServicePrincipalToken {
template := x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{CommonName: "test"},
BasicConstraintsValid: true,
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
certificateBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatal(err)
}
certificate, err := x509.ParseCertificate(certificateBytes)
if err != nil {
t.Fatal(err)
}
spt, _ := NewServicePrincipalTokenFromCertificate(TestOAuthConfig, "id", certificate, privateKey, "resource")
return spt
}