Update go dependencies

This commit is contained in:
Manuel de Brito Fontes 2018-05-26 11:27:53 -04:00 committed by Manuel Alejandro de Brito Fontes
parent 15ffb51394
commit bb4d483837
No known key found for this signature in database
GPG key ID: 786136016A8BA02A
1621 changed files with 86368 additions and 284392 deletions

View file

@ -1,5 +1,98 @@
# CHANGELOG
## v9.9.0
### New Features
- Added EventGridKeyAuthorizer for key authorization with event grid topics.
### Bug Fixes
- Fixed race condition when auto-refreshing service principal tokens.
## v9.8.1
### Bug Fixes
- Added http.StatusNoContent (204) to the list of expected status codes for long-running operations.
- Updated runtime version info so it's current.
## v9.8.0
### New Features
- Added type azure.AsyncOpIncompleteError to be returned from a future's Result() method when the operation has not completed.
## v9.7.1
### Bug Fixes
- Use correct AAD and Graph endpoints for US Gov environment.
## v9.7.0
### New Features
- Added support for application/octet-stream MIME types.
## v9.6.1
### Bug Fixes
- Ensure Authorization header is added to request when polling for registration status.
## v9.6.0
### New Features
- Added support for acquiring tokens via MSI with a user assigned identity.
## v9.5.3
### Bug Fixes
- Don't remove encoding of existing URL Query parameters when calling autorest.WithQueryParameters.
- Set correct Content Type when using autorest.WithFormData.
## v9.5.2
### Bug Fixes
- Check for nil *http.Response before dereferencing it.
## v9.5.1
### Bug Fixes
- Don't count http.StatusTooManyRequests (429) against the retry cap.
- Use retry logic when SkipResourceProviderRegistration is set to true.
## v9.5.0
### New Features
- Added support for username + password, API key, authoriazation code and cognitive services authentication.
- Added field SkipResourceProviderRegistration to clients to provide a way to skip auto-registration of RPs.
- Added utility function AsStringSlice() to convert its parameters to a string slice.
### Bug Fixes
- When checking for authentication failures look at the error type not the status code as it could vary.
## v9.4.2
### Bug Fixes
- Validate parameters when creating credentials.
- Don't retry requests if the returned status is a 401 (http.StatusUnauthorized) as it will never succeed.
## v9.4.1
### Bug Fixes
- Update the AccessTokensPath() to read access tokens path through AZURE_ACCESS_TOKEN_FILE. If this
environment variable is not set, it will fall back to use default path set by Azure CLI.
- Use case-insensitive string comparison for polling states.
## v9.4.0
### New Features
@ -106,7 +199,7 @@ Support for UNIX time.
- Added telemetry.
## v7.2.3
- Fixing bug in calls to `DelayForBackoff` that caused doubling of delay
- Fixing bug in calls to `DelayForBackoff` that caused doubling of delay
duration.
## v7.2.2

View file

@ -218,6 +218,40 @@ if (err == nil) {
}
```
#### Username password authenticate
```Go
spt, err := adal.NewServicePrincipalTokenFromUsernamePassword(
oauthConfig,
applicationID,
username,
password,
resource,
callbacks...)
if (err == nil) {
token := spt.Token
}
```
#### Authorization code authenticate
``` Go
spt, err := adal.NewServicePrincipalTokenFromAuthorizationCode(
oauthConfig,
applicationID,
clientSecret,
authorizationCode,
redirectURI,
resource,
callbacks...)
err = spt.Refresh()
if (err == nil) {
token := spt.Token
}
```
### Command Line Tool
A command line tool is available in `cmd/adal.go` that can acquire a token for a given resource. It supports all flows mentioned above.

View file

@ -32,8 +32,24 @@ type OAuthConfig struct {
DeviceCodeEndpoint url.URL
}
// IsZero returns true if the OAuthConfig object is zero-initialized.
func (oac OAuthConfig) IsZero() bool {
return oac == OAuthConfig{}
}
func validateStringParam(param, name string) error {
if len(param) == 0 {
return fmt.Errorf("parameter '" + name + "' cannot be empty")
}
return nil
}
// NewOAuthConfig returns an OAuthConfig with tenant specific urls
func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, error) {
if err := validateStringParam(activeDirectoryEndpoint, "activeDirectoryEndpoint"); err != nil {
return nil, err
}
// it's legal for tenantID to be empty so don't validate it
const activeDirectoryEndpointTemplate = "%s/oauth2/%s?api-version=%s"
u, err := url.Parse(activeDirectoryEndpoint)
if err != nil {

View file

@ -1,44 +0,0 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"testing"
)
func TestNewOAuthConfig(t *testing.T) {
const testActiveDirectoryEndpoint = "https://login.test.com"
const testTenantID = "tenant-id-test"
config, err := NewOAuthConfig(testActiveDirectoryEndpoint, testTenantID)
if err != nil {
t.Fatalf("autorest/adal: Unexpected error while creating oauth configuration for tenant: %v.", err)
}
expected := "https://login.test.com/tenant-id-test/oauth2/authorize?api-version=1.0"
if config.AuthorizeEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.AuthorizeEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/token?api-version=1.0"
if config.TokenEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.TokenEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/devicecode?api-version=1.0"
if config.DeviceCodeEndpoint.String() != expected {
t.Fatalf("autorest/adal Incorrect devicecode url for Tenant from Environment. expected(%s). actual(%v).", expected, config.DeviceCodeEndpoint)
}
}

View file

@ -1,330 +0,0 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
TestResource = "SomeResource"
TestClientID = "SomeClientID"
TestTenantID = "SomeTenantID"
TestActiveDirectoryEndpoint = "https://login.test.com/"
)
var (
testOAuthConfig, _ = NewOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID)
TestOAuthConfig = *testOAuthConfig
)
const MockDeviceCodeResponse = `
{
"device_code": "10000-40-1234567890",
"user_code": "ABCDEF",
"verification_url": "http://aka.ms/deviceauth",
"expires_in": "900",
"interval": "0"
}
`
const MockDeviceTokenResponse = `{
"access_token": "accessToken",
"refresh_token": "refreshToken",
"expires_in": "1000",
"expires_on": "2000",
"not_before": "3000",
"resource": "resource",
"token_type": "type"
}
`
func TestDeviceCodeIncludesResource(t *testing.T) {
sender := mocks.NewSender()
sender.AppendResponse(mocks.NewResponseWithContent(MockDeviceCodeResponse))
code, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err != nil {
t.Fatalf("adal: unexpected error initiating device auth")
}
if code.Resource != TestResource {
t.Fatalf("adal: InitiateDeviceAuth failed to stash the resource in the DeviceCode struct")
}
}
func TestDeviceCodeReturnsErrorIfSendingFails(t *testing.T) {
sender := mocks.NewSender()
sender.SetError(fmt.Errorf("this is an error"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err == nil || !strings.Contains(err.Error(), errCodeSendingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errCodeSendingFails, err.Error())
}
}
func TestDeviceCodeReturnsErrorIfBadRequest(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("doesn't matter")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err == nil || !strings.Contains(err.Error(), errCodeHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errCodeHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceCodeReturnsErrorIfCannotDeserializeDeviceCode(t *testing.T) {
gibberishJSON := strings.Replace(MockDeviceCodeResponse, "expires_in", "\":, :gibberish", -1)
sender := mocks.NewSender()
body := mocks.NewBody(gibberishJSON)
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err == nil || !strings.Contains(err.Error(), errCodeHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errCodeHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceCodeReturnsErrorIfEmptyDeviceCode(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err != ErrDeviceCodeEmpty {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", ErrDeviceCodeEmpty, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func deviceCode() *DeviceCode {
var deviceCode DeviceCode
_ = json.Unmarshal([]byte(MockDeviceCodeResponse), &deviceCode)
deviceCode.Resource = TestResource
deviceCode.ClientID = TestClientID
return &deviceCode
}
func TestDeviceTokenReturns(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(MockDeviceTokenResponse)
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != nil {
t.Fatalf("adal: got error unexpectedly")
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfSendingFails(t *testing.T) {
sender := mocks.NewSender()
sender.SetError(fmt.Errorf("this is an error"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil || !strings.Contains(err.Error(), errTokenSendingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errTokenSendingFails, err.Error())
}
}
func TestDeviceTokenReturnsErrorIfServerError(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusInternalServerError, "Internal Server Error"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil || !strings.Contains(err.Error(), errTokenHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errTokenHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfCannotDeserializeDeviceToken(t *testing.T) {
gibberishJSON := strings.Replace(MockDeviceTokenResponse, "expires_in", ";:\"gibberish", -1)
sender := mocks.NewSender()
body := mocks.NewBody(gibberishJSON)
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil || !strings.Contains(err.Error(), errTokenHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errTokenHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func errorDeviceTokenResponse(message string) string {
return `{ "error": "` + message + `" }`
}
func TestDeviceTokenReturnsErrorIfAuthorizationPending(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("authorization_pending"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := CheckForUserCompletion(sender, deviceCode())
if err != ErrDeviceAuthorizationPending {
t.Fatalf("!!!")
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfSlowDown(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("slow_down"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := CheckForUserCompletion(sender, deviceCode())
if err != ErrDeviceSlowDown {
t.Fatalf("!!!")
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
type deviceTokenSender struct {
errorString string
attempts int
}
func newDeviceTokenSender(deviceErrorString string) *deviceTokenSender {
return &deviceTokenSender{errorString: deviceErrorString, attempts: 0}
}
func (s *deviceTokenSender) Do(req *http.Request) (*http.Response, error) {
var resp *http.Response
if s.attempts < 1 {
s.attempts++
resp = mocks.NewResponseWithContent(errorDeviceTokenResponse(s.errorString))
} else {
resp = mocks.NewResponseWithContent(MockDeviceTokenResponse)
}
return resp, nil
}
// since the above only exercise CheckForUserCompletion, we repeat the test here,
// but with the intent of showing that WaitForUserCompletion loops properly.
func TestDeviceTokenSucceedsWithIntermediateAuthPending(t *testing.T) {
sender := newDeviceTokenSender("authorization_pending")
_, err := WaitForUserCompletion(sender, deviceCode())
if err != nil {
t.Fatalf("unexpected error occurred")
}
}
// same as above but with SlowDown now
func TestDeviceTokenSucceedsWithIntermediateSlowDown(t *testing.T) {
sender := newDeviceTokenSender("slow_down")
_, err := WaitForUserCompletion(sender, deviceCode())
if err != nil {
t.Fatalf("unexpected error occurred")
}
}
func TestDeviceTokenReturnsErrorIfAccessDenied(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("access_denied"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != ErrDeviceAccessDenied {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrDeviceAccessDenied.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfCodeExpired(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("code_expired"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != ErrDeviceCodeExpired {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrDeviceCodeExpired.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorForUnknownError(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("unknown_error"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil {
t.Fatalf("failed to get error")
}
if err != ErrDeviceGeneric {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrDeviceGeneric.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfTokenEmptyAndStatusOK(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != ErrOAuthTokenEmpty {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrOAuthTokenEmpty.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}

View file

@ -1,171 +0,0 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"io/ioutil"
"os"
"path"
"reflect"
"runtime"
"strings"
"testing"
)
const MockTokenJSON string = `{
"access_token": "accessToken",
"refresh_token": "refreshToken",
"expires_in": "1000",
"expires_on": "2000",
"not_before": "3000",
"resource": "resource",
"token_type": "type"
}`
var TestToken = Token{
AccessToken: "accessToken",
RefreshToken: "refreshToken",
ExpiresIn: "1000",
ExpiresOn: "2000",
NotBefore: "3000",
Resource: "resource",
Type: "type",
}
func writeTestTokenFile(t *testing.T, suffix string, contents string) *os.File {
f, err := ioutil.TempFile(os.TempDir(), suffix)
if err != nil {
t.Fatalf("azure: unexpected error when creating temp file: %v", err)
}
defer f.Close()
_, err = f.Write([]byte(contents))
if err != nil {
t.Fatalf("azure: unexpected error when writing temp test file: %v", err)
}
return f
}
func TestLoadToken(t *testing.T) {
f := writeTestTokenFile(t, "testloadtoken", MockTokenJSON)
defer os.Remove(f.Name())
expectedToken := TestToken
actualToken, err := LoadToken(f.Name())
if err != nil {
t.Fatalf("azure: unexpected error loading token from file: %v", err)
}
if *actualToken != expectedToken {
t.Fatalf("azure: failed to decode properly expected(%v) actual(%v)", expectedToken, *actualToken)
}
// test that LoadToken closes the file properly
err = SaveToken(f.Name(), 0600, *actualToken)
if err != nil {
t.Fatalf("azure: could not save token after LoadToken: %v", err)
}
}
func TestLoadTokenFailsBadPath(t *testing.T) {
_, err := LoadToken("/tmp/this_file_should_never_exist_really")
expectedSubstring := "failed to open file"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%s)", expectedSubstring, err.Error())
}
}
func TestLoadTokenFailsBadJson(t *testing.T) {
gibberishJSON := strings.Replace(MockTokenJSON, "expires_on", ";:\"gibberish", -1)
f := writeTestTokenFile(t, "testloadtokenfailsbadjson", gibberishJSON)
defer os.Remove(f.Name())
_, err := LoadToken(f.Name())
expectedSubstring := "failed to decode contents of file"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%s)", expectedSubstring, err.Error())
}
}
func token() *Token {
var token Token
json.Unmarshal([]byte(MockTokenJSON), &token)
return &token
}
func TestSaveToken(t *testing.T) {
f, err := ioutil.TempFile("", "testloadtoken")
if err != nil {
t.Fatalf("azure: unexpected error when creating temp file: %v", err)
}
defer os.Remove(f.Name())
f.Close()
mode := os.ModePerm & 0642
err = SaveToken(f.Name(), mode, *token())
if err != nil {
t.Fatalf("azure: unexpected error saving token to file: %v", err)
}
fi, err := os.Stat(f.Name()) // open a new stat as held ones are not fresh
if err != nil {
t.Fatalf("azure: stat failed: %v", err)
}
if runtime.GOOS != "windows" { // permissions don't work on Windows
if perm := fi.Mode().Perm(); perm != mode {
t.Fatalf("azure: wrong file perm. got:%s; expected:%s file :%s", perm, mode, f.Name())
}
}
var actualToken Token
var expectedToken Token
json.Unmarshal([]byte(MockTokenJSON), expectedToken)
contents, err := ioutil.ReadFile(f.Name())
if err != nil {
t.Fatal("!!")
}
json.Unmarshal(contents, actualToken)
if !reflect.DeepEqual(actualToken, expectedToken) {
t.Fatal("azure: token was not serialized correctly")
}
}
func TestSaveTokenFailsNoPermission(t *testing.T) {
pathWhereWeShouldntHavePermission := "/usr/thiswontwork/atall"
if runtime.GOOS == "windows" {
pathWhereWeShouldntHavePermission = path.Join(os.Getenv("windir"), "system32\\mytokendir\\mytoken")
}
err := SaveToken(pathWhereWeShouldntHavePermission, 0644, *token())
expectedSubstring := "failed to create directory"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%v)", expectedSubstring, err)
}
}
func TestSaveTokenFailsCantCreate(t *testing.T) {
tokenPath := "/thiswontwork"
if runtime.GOOS == "windows" {
tokenPath = path.Join(os.Getenv("windir"), "system32")
}
err := SaveToken(tokenPath, 0644, *token())
expectedSubstring := "failed to create the temp file to write the token"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%v)", expectedSubstring, err)
}
}

View file

@ -27,6 +27,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/Azure/go-autorest/autorest/date"
@ -42,9 +43,15 @@ const (
// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
OAuthGrantTypeClientCredentials = "client_credentials"
// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
OAuthGrantTypeUserPass = "password"
// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
OAuthGrantTypeRefreshToken = "refresh_token"
// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
OAuthGrantTypeAuthorizationCode = "authorization_code"
// metadataHeader is the header required by MSI extension
metadataHeader = "Metadata"
)
@ -54,6 +61,12 @@ type OAuthTokenProvider interface {
OAuthToken() string
}
// TokenRefreshError is an interface used by errors returned during token refresh.
type TokenRefreshError interface {
error
Response() *http.Response
}
// Refresher is an interface for token refresh functionality
type Refresher interface {
Refresh() error
@ -78,6 +91,11 @@ type Token struct {
Type string `json:"token_type"`
}
// IsZero returns true if the token object is zero-initialized.
func (t Token) IsZero() bool {
return t == Token{}
}
// Expires returns the time.Time when the Token expires.
func (t Token) Expires() time.Time {
s, err := strconv.Atoi(t.ExpiresOn)
@ -145,6 +163,34 @@ type ServicePrincipalCertificateSecret struct {
type ServicePrincipalMSISecret struct {
}
// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
type ServicePrincipalUsernamePasswordSecret struct {
Username string
Password string
}
// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
type ServicePrincipalAuthorizationCodeSecret struct {
ClientSecret string
AuthorizationCode string
RedirectURI string
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("code", secret.AuthorizationCode)
v.Set("client_secret", secret.ClientSecret)
v.Set("redirect_uri", secret.RedirectURI)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("username", secret.Username)
v.Set("password", secret.Password)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil
@ -199,25 +245,46 @@ func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *Se
type ServicePrincipalToken struct {
Token
secret ServicePrincipalSecret
oauthConfig OAuthConfig
clientID string
resource string
autoRefresh bool
refreshWithin time.Duration
sender Sender
secret ServicePrincipalSecret
oauthConfig OAuthConfig
clientID string
resource string
autoRefresh bool
autoRefreshLock *sync.Mutex
refreshWithin time.Duration
sender Sender
refreshCallbacks []TokenRefreshCallback
}
func validateOAuthConfig(oac OAuthConfig) error {
if oac.IsZero() {
return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
}
return nil
}
// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(id, "id"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if secret == nil {
return nil, fmt.Errorf("parameter 'secret' cannot be nil")
}
spt := &ServicePrincipalToken{
oauthConfig: oauthConfig,
secret: secret,
clientID: id,
resource: resource,
autoRefresh: true,
autoRefreshLock: &sync.Mutex{},
refreshWithin: defaultRefresh,
sender: &http.Client{},
refreshCallbacks: callbacks,
@ -227,6 +294,18 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso
// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if token.IsZero() {
return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
}
spt, err := NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
@ -245,6 +324,18 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s
// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
// credentials scoped to the named resource.
func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(secret, "secret"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
@ -256,8 +347,23 @@ func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret s
)
}
// NewServicePrincipalTokenFromCertificate create a ServicePrincipalToken from the supplied pkcs12 bytes.
// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if certificate == nil {
return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
}
if privateKey == nil {
return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
@ -270,6 +376,70 @@ func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID s
)
}
// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(username, "username"); err != nil {
return nil, err
}
if err := validateStringParam(password, "password"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalUsernamePasswordSecret{
Username: username,
Password: password,
},
callbacks...,
)
}
// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
return nil, err
}
if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
return nil, err
}
if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalAuthorizationCodeSecret{
ClientSecret: clientSecret,
AuthorizationCode: authorizationCode,
RedirectURI: redirectURI,
},
callbacks...,
)
}
// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
func GetMSIVMEndpoint() (string, error) {
return getMSIVMEndpoint(msiPath)
@ -293,7 +463,29 @@ func getMSIVMEndpoint(path string) (string, error) {
}
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the system assigned identity when creating the token.
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
}
// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the specified user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
}
func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if userAssignedID != nil {
if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
return nil, err
}
}
// We set the oauth config token endpoint to be MSI's endpoint
msiEndpointURL, err := url.Parse(msiEndpoint)
if err != nil {
@ -310,19 +502,49 @@ func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...
secret: &ServicePrincipalMSISecret{},
resource: resource,
autoRefresh: true,
autoRefreshLock: &sync.Mutex{},
refreshWithin: defaultRefresh,
sender: &http.Client{},
refreshCallbacks: callbacks,
}
if userAssignedID != nil {
spt.clientID = *userAssignedID
}
return spt, nil
}
// internal type that implements TokenRefreshError
type tokenRefreshError struct {
message string
resp *http.Response
}
// Error implements the error interface which is part of the TokenRefreshError interface.
func (tre tokenRefreshError) Error() string {
return tre.message
}
// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
func (tre tokenRefreshError) Response() *http.Response {
return tre.resp
}
func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
return tokenRefreshError{message: message, resp: resp}
}
// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on.
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFresh() error {
if spt.autoRefresh && spt.WillExpireIn(spt.refreshWithin) {
return spt.Refresh()
// take the lock then check to see if the token was already refreshed
spt.autoRefreshLock.Lock()
defer spt.autoRefreshLock.Unlock()
if spt.WillExpireIn(spt.refreshWithin) {
return spt.Refresh()
}
}
return nil
}
@ -341,15 +563,28 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
}
// Refresh obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) Refresh() error {
return spt.refreshInternal(spt.resource)
}
// RefreshExchange refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
return spt.refreshInternal(resource)
}
func (spt *ServicePrincipalToken) getGrantType() string {
switch spt.secret.(type) {
case *ServicePrincipalUsernamePasswordSecret:
return OAuthGrantTypeUserPass
case *ServicePrincipalAuthorizationCodeSecret:
return OAuthGrantTypeAuthorizationCode
default:
return OAuthGrantTypeClientCredentials
}
}
func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
v := url.Values{}
v.Set("client_id", spt.clientID)
@ -359,7 +594,7 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
v.Set("grant_type", OAuthGrantTypeRefreshToken)
v.Set("refresh_token", spt.RefreshToken)
} else {
v.Set("grant_type", OAuthGrantTypeClientCredentials)
v.Set("grant_type", spt.getGrantType())
err := spt.secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
@ -388,9 +623,9 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
if resp.StatusCode != http.StatusOK {
if err != nil {
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode)
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode), resp)
}
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb))
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
}
if err != nil {

View file

@ -1,654 +0,0 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
defaultFormData = "client_id=id&client_secret=secret&grant_type=client_credentials&resource=resource"
defaultManualFormData = "client_id=id&grant_type=refresh_token&refresh_token=refreshtoken&resource=resource"
)
func TestTokenExpires(t *testing.T) {
tt := time.Now().Add(5 * time.Second)
tk := newTokenExpiresAt(tt)
if tk.Expires().Equal(tt) {
t.Fatalf("adal: Token#Expires miscalculated expiration time -- received %v, expected %v", tk.Expires(), tt)
}
}
func TestTokenIsExpired(t *testing.T) {
tk := newTokenExpiresAt(time.Now().Add(-5 * time.Second))
if !tk.IsExpired() {
t.Fatalf("adal: Token#IsExpired failed to mark a stale token as expired -- now %v, token expires at %v",
time.Now().UTC(), tk.Expires())
}
}
func TestTokenIsExpiredUninitialized(t *testing.T) {
tk := &Token{}
if !tk.IsExpired() {
t.Fatalf("adal: An uninitialized Token failed to mark itself as expired (expiration time %v)", tk.Expires())
}
}
func TestTokenIsNoExpired(t *testing.T) {
tk := newTokenExpiresAt(time.Now().Add(1000 * time.Second))
if tk.IsExpired() {
t.Fatalf("adal: Token marked a fresh token as expired -- now %v, token expires at %v", time.Now().UTC(), tk.Expires())
}
}
func TestTokenWillExpireIn(t *testing.T) {
d := 5 * time.Second
tk := newTokenExpiresIn(d)
if !tk.WillExpireIn(d) {
t.Fatal("adal: Token#WillExpireIn mismeasured expiration time")
}
}
func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
spt := newServicePrincipalToken()
if !spt.autoRefresh {
t.Fatal("adal: ServicePrincipalToken did not default to automatic token refreshing")
}
spt.SetAutoRefresh(false)
if spt.autoRefresh {
t.Fatal("adal: ServicePrincipalToken#SetAutoRefresh did not disable automatic token refreshing")
}
}
func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
spt := newServicePrincipalToken()
if spt.refreshWithin != defaultRefresh {
t.Fatal("adal: ServicePrincipalToken did not correctly set the default refresh interval")
}
spt.SetRefreshWithin(2 * defaultRefresh)
if spt.refreshWithin != 2*defaultRefresh {
t.Fatal("adal: ServicePrincipalToken#SetRefreshWithin did not set the refresh interval")
}
}
func TestServicePrincipalTokenSetSender(t *testing.T) {
spt := newServicePrincipalToken()
c := &http.Client{}
spt.SetSender(c)
if !reflect.DeepEqual(c, spt.sender) {
t.Fatal("adal: ServicePrincipalToken#SetSender did not set the sender")
}
}
func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != "POST" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
}
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if body.IsOpen() {
t.Fatalf("the response was not closed!")
}
}
func TestServicePrincipalTokenFromMSIRefreshUsesPOST(t *testing.T) {
resource := "https://resource"
cb := func(token Token) error { return nil }
spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != "POST" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
}
if h := r.Header.Get("Metadata"); h != "true" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
}
return resp, nil
})
}
})())
spt.SetSender(s)
err = spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if body.IsOpen() {
t.Fatalf("the response was not closed!")
}
}
func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Header.Get(http.CanonicalHeaderKey("Content-Type")) != "application/x-www-form-urlencoded" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Content-Type -- expected %v, received %v",
"application/x-form-urlencoded",
r.Header.Get(http.CanonicalHeaderKey("Content-Type")))
}
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
}
func TestServicePrincipalTokenRefreshSetsURL(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.String() != TestOAuthConfig.TokenEndpoint.String() {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the URL -- expected %v, received %v",
TestOAuthConfig.TokenEndpoint, r.URL)
}
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
}
func testServicePrincipalTokenRefreshSetsBody(t *testing.T, spt *ServicePrincipalToken, f func(*testing.T, []byte)) {
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("adal: Failed to read body of Service Principal token request (%v)", err)
}
f(t, b)
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
}
func TestServicePrincipalTokenManualRefreshSetsBody(t *testing.T) {
sptManual := newServicePrincipalTokenManual()
testServicePrincipalTokenRefreshSetsBody(t, sptManual, func(t *testing.T, b []byte) {
if string(b) != defaultManualFormData {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
defaultManualFormData, string(b))
}
})
}
func TestServicePrincipalTokenCertficateRefreshSetsBody(t *testing.T) {
sptCert := newServicePrincipalTokenCertificate(t)
testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) {
body := string(b)
values, _ := url.ParseQuery(body)
if values["client_assertion_type"][0] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ||
values["client_id"][0] != "id" ||
values["grant_type"][0] != "client_credentials" ||
values["resource"][0] != "resource" {
t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
}
})
}
func TestServicePrincipalTokenSecretRefreshSetsBody(t *testing.T) {
spt := newServicePrincipalToken()
testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
if string(b) != defaultFormData {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
defaultFormData, string(b))
}
})
}
func TestServicePrincipalTokenRefreshClosesRequestBody(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if resp.Body.(*mocks.Body).IsOpen() {
t.Fatal("adal: ServicePrincipalToken#Refresh failed to close the HTTP Response Body")
}
}
func TestServicePrincipalTokenRefreshRejectsResponsesWithStatusNotOK(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusUnauthorized, "Unauthorized")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh should reject a response with status != %d", http.StatusOK)
}
}
func TestServicePrincipalTokenRefreshRejectsEmptyBody(t *testing.T) {
spt := newServicePrincipalToken()
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return mocks.NewResponse(), nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err == nil {
t.Fatal("adal: ServicePrincipalToken#Refresh should reject an empty token")
}
}
func TestServicePrincipalTokenRefreshPropagatesErrors(t *testing.T) {
spt := newServicePrincipalToken()
c := mocks.NewSender()
c.SetError(fmt.Errorf("Faux Error"))
spt.SetSender(c)
err := spt.Refresh()
if err == nil {
t.Fatal("adal: Failed to propagate the request error")
}
}
func TestServicePrincipalTokenRefreshReturnsErrorIfNotOk(t *testing.T) {
spt := newServicePrincipalToken()
c := mocks.NewSender()
c.AppendResponse(mocks.NewResponseWithStatus("401 NotAuthorized", http.StatusUnauthorized))
spt.SetSender(c)
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: Failed to return an when receiving a status code other than HTTP %d", http.StatusOK)
}
}
func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
spt := newServicePrincipalToken()
expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
j := newTokenJSON(expiresOn, "resource")
resp := mocks.NewResponseWithContent(j)
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
} else if spt.AccessToken != "accessToken" ||
spt.ExpiresIn != "3600" ||
spt.ExpiresOn != expiresOn ||
spt.NotBefore != expiresOn ||
spt.Resource != "resource" ||
spt.Type != "Bearer" {
t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v",
j, *spt)
}
}
func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
spt := newServicePrincipalToken()
expireToken(&spt.Token)
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
f := false
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.EnsureFresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
}
if !f {
t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
}
}
func TestServicePrincipalTokenEnsureFreshSkipsIfFresh(t *testing.T) {
spt := newServicePrincipalToken()
setTokenToExpireIn(&spt.Token, 1000*time.Second)
f := false
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return mocks.NewResponse(), nil
})
}
})())
spt.SetSender(s)
err := spt.EnsureFresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
}
if f {
t.Fatal("adal: ServicePrincipalToken#EnsureFresh invoked Refresh for fresh token")
}
}
func TestRefreshCallback(t *testing.T) {
callbackTriggered := false
spt := newServicePrincipalToken(func(Token) error {
callbackTriggered = true
return nil
})
expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
sender := mocks.NewSender()
j := newTokenJSON(expiresOn, "resource")
sender.AppendResponse(mocks.NewResponseWithContent(j))
spt.SetSender(sender)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if !callbackTriggered {
t.Fatalf("adal: RefreshCallback failed to trigger call callback")
}
}
func TestRefreshCallbackErrorPropagates(t *testing.T) {
errorText := "this is an error text"
spt := newServicePrincipalToken(func(Token) error {
return fmt.Errorf(errorText)
})
expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
sender := mocks.NewSender()
j := newTokenJSON(expiresOn, "resource")
sender.AppendResponse(mocks.NewResponseWithContent(j))
spt.SetSender(sender)
err := spt.Refresh()
if err == nil || !strings.Contains(err.Error(), errorText) {
t.Fatalf("adal: RefreshCallback failed to propagate error")
}
}
// This demonstrates the danger of manual token without a refresh token
func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
spt := newServicePrincipalTokenManual()
spt.RefreshToken = ""
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh should have failed with a ManualTokenSecret without a refresh token")
}
}
func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
resource := "https://resource"
cb := func(token Token) error { return nil }
spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}
// check some of the SPT fields
if _, ok := spt.secret.(*ServicePrincipalMSISecret); !ok {
t.Fatal("SPT secret was not of MSI type")
}
if spt.resource != resource {
t.Fatal("SPT came back with incorrect resource")
}
if len(spt.refreshCallbacks) != 1 {
t.Fatal("SPT had incorrect refresh callbacks.")
}
}
func TestGetVMEndpoint(t *testing.T) {
tempSettingsFile, err := ioutil.TempFile("", "ManagedIdentity-Settings")
if err != nil {
t.Fatal("Couldn't write temp settings file")
}
defer os.Remove(tempSettingsFile.Name())
settingsContents := []byte(`{
"url": "http://msiendpoint/"
}`)
if _, err := tempSettingsFile.Write(settingsContents); err != nil {
t.Fatal("Couldn't fill temp settings file")
}
endpoint, err := getMSIVMEndpoint(tempSettingsFile.Name())
if err != nil {
t.Fatal("Coudn't get VM endpoint")
}
if endpoint != "http://msiendpoint/" {
t.Fatal("Didn't get correct endpoint")
}
}
func newToken() *Token {
return &Token{
AccessToken: "ASECRETVALUE",
Resource: "https://azure.microsoft.com/",
Type: "Bearer",
}
}
func newTokenJSON(expiresOn string, resource string) string {
return fmt.Sprintf(`{
"access_token" : "accessToken",
"expires_in" : "3600",
"expires_on" : "%s",
"not_before" : "%s",
"resource" : "%s",
"token_type" : "Bearer"
}`,
expiresOn, expiresOn, resource)
}
func newTokenExpiresIn(expireIn time.Duration) *Token {
return setTokenToExpireIn(newToken(), expireIn)
}
func newTokenExpiresAt(expireAt time.Time) *Token {
return setTokenToExpireAt(newToken(), expireAt)
}
func expireToken(t *Token) *Token {
return setTokenToExpireIn(t, 0)
}
func setTokenToExpireAt(t *Token, expireAt time.Time) *Token {
t.ExpiresIn = "3600"
t.ExpiresOn = strconv.Itoa(int(expireAt.Sub(date.UnixEpoch()).Seconds()))
t.NotBefore = t.ExpiresOn
return t
}
func setTokenToExpireIn(t *Token, expireIn time.Duration) *Token {
return setTokenToExpireAt(t, time.Now().Add(expireIn))
}
func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincipalToken {
spt, _ := NewServicePrincipalToken(TestOAuthConfig, "id", "secret", "resource", callbacks...)
return spt
}
func newServicePrincipalTokenManual() *ServicePrincipalToken {
token := newToken()
token.RefreshToken = "refreshtoken"
spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", *token)
return spt
}
func newServicePrincipalTokenCertificate(t *testing.T) *ServicePrincipalToken {
template := x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{CommonName: "test"},
BasicConstraintsValid: true,
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
certificateBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatal(err)
}
certificate, err := x509.ParseCertificate(certificateBytes)
if err != nil {
t.Fatal(err)
}
spt, _ := NewServicePrincipalTokenFromCertificate(TestOAuthConfig, "id", certificate, privateKey, "resource")
return spt
}

View file

@ -24,9 +24,12 @@ import (
)
const (
bearerChallengeHeader = "Www-Authenticate"
bearer = "Bearer"
tenantID = "tenantID"
bearerChallengeHeader = "Www-Authenticate"
bearer = "Bearer"
tenantID = "tenantID"
apiKeyAuthorizerHeader = "Ocp-Apim-Subscription-Key"
bingAPISdkHeader = "X-BingApis-SDK-Client"
golangBingAPISdkHeaderValue = "Go-SDK"
)
// Authorizer is the interface that provides a PrepareDecorator used to supply request
@ -44,6 +47,53 @@ func (na NullAuthorizer) WithAuthorization() PrepareDecorator {
return WithNothing()
}
// APIKeyAuthorizer implements API Key authorization.
type APIKeyAuthorizer struct {
headers map[string]interface{}
queryParameters map[string]interface{}
}
// NewAPIKeyAuthorizerWithHeaders creates an ApiKeyAuthorizer with headers.
func NewAPIKeyAuthorizerWithHeaders(headers map[string]interface{}) *APIKeyAuthorizer {
return NewAPIKeyAuthorizer(headers, nil)
}
// NewAPIKeyAuthorizerWithQueryParameters creates an ApiKeyAuthorizer with query parameters.
func NewAPIKeyAuthorizerWithQueryParameters(queryParameters map[string]interface{}) *APIKeyAuthorizer {
return NewAPIKeyAuthorizer(nil, queryParameters)
}
// NewAPIKeyAuthorizer creates an ApiKeyAuthorizer with headers.
func NewAPIKeyAuthorizer(headers map[string]interface{}, queryParameters map[string]interface{}) *APIKeyAuthorizer {
return &APIKeyAuthorizer{headers: headers, queryParameters: queryParameters}
}
// WithAuthorization returns a PrepareDecorator that adds an HTTP headers and Query Paramaters
func (aka *APIKeyAuthorizer) WithAuthorization() PrepareDecorator {
return func(p Preparer) Preparer {
return DecoratePreparer(p, WithHeaders(aka.headers), WithQueryParameters(aka.queryParameters))
}
}
// CognitiveServicesAuthorizer implements authorization for Cognitive Services.
type CognitiveServicesAuthorizer struct {
subscriptionKey string
}
// NewCognitiveServicesAuthorizer is
func NewCognitiveServicesAuthorizer(subscriptionKey string) *CognitiveServicesAuthorizer {
return &CognitiveServicesAuthorizer{subscriptionKey: subscriptionKey}
}
// WithAuthorization is
func (csa *CognitiveServicesAuthorizer) WithAuthorization() PrepareDecorator {
headers := make(map[string]interface{})
headers[apiKeyAuthorizerHeader] = csa.subscriptionKey
headers[bingAPISdkHeader] = golangBingAPISdkHeaderValue
return NewAPIKeyAuthorizerWithHeaders(headers).WithAuthorization()
}
// BearerAuthorizer implements the bearer authorization
type BearerAuthorizer struct {
tokenProvider adal.OAuthTokenProvider
@ -69,7 +119,11 @@ func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator {
if ok {
err := refresher.EnsureFresh()
if err != nil {
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", nil,
var resp *http.Response
if tokError, ok := err.(adal.TokenRefreshError); ok {
resp = tokError.Response()
}
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp,
"Failed to refresh the Token for request to %s", r.URL)
}
}
@ -179,3 +233,22 @@ func newBearerChallenge(resp *http.Response) (bc bearerChallenge, err error) {
return bc, err
}
// EventGridKeyAuthorizer implements authorization for event grid using key authentication.
type EventGridKeyAuthorizer struct {
topicKey string
}
// NewEventGridKeyAuthorizer creates a new EventGridKeyAuthorizer
// with the specified topic key.
func NewEventGridKeyAuthorizer(topicKey string) EventGridKeyAuthorizer {
return EventGridKeyAuthorizer{topicKey: topicKey}
}
// WithAuthorization returns a PrepareDecorator that adds the aeg-sas-key authentication header.
func (egta EventGridKeyAuthorizer) WithAuthorization() PrepareDecorator {
headers := map[string]interface{}{
"aeg-sas-key": egta.topicKey,
}
return NewAPIKeyAuthorizerWithHeaders(headers).WithAuthorization()
}

View file

@ -1,188 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"net/http"
"reflect"
"testing"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
TestTenantID = "TestTenantID"
TestActiveDirectoryEndpoint = "https://login/test.com/"
)
func TestWithAuthorizer(t *testing.T) {
r1 := mocks.NewRequest()
na := &NullAuthorizer{}
r2, err := Prepare(r1,
na.WithAuthorization())
if err != nil {
t.Fatalf("autorest: NullAuthorizer#WithAuthorization returned an unexpected error (%v)", err)
} else if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: NullAuthorizer#WithAuthorization modified the request -- received %v, expected %v", r2, r1)
}
}
func TestTokenWithAuthorization(t *testing.T) {
token := &adal.Token{
AccessToken: "TestToken",
Resource: "https://azure.microsoft.com/",
Type: "Bearer",
}
ba := NewBearerAuthorizer(token)
req, err := Prepare(&http.Request{}, ba.WithAuthorization())
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", token.AccessToken) {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to set Authorization header")
}
}
func TestServicePrincipalTokenWithAuthorizationNoRefresh(t *testing.T) {
oauthConfig, err := adal.NewOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
spt, err := adal.NewServicePrincipalToken(*oauthConfig, "id", "secret", "resource", nil)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
spt.SetAutoRefresh(false)
s := mocks.NewSender()
spt.SetSender(s)
ba := NewBearerAuthorizer(spt)
req, err := Prepare(mocks.NewRequest(), ba.WithAuthorization())
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", spt.AccessToken) {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to set Authorization header")
}
}
func TestServicePrincipalTokenWithAuthorizationRefresh(t *testing.T) {
oauthConfig, err := adal.NewOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
refreshed := false
spt, err := adal.NewServicePrincipalToken(*oauthConfig, "id", "secret", "resource", func(t adal.Token) error {
refreshed = true
return nil
})
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
jwt := `{
"access_token" : "accessToken",
"expires_in" : "3600",
"expires_on" : "test",
"not_before" : "test",
"resource" : "test",
"token_type" : "Bearer"
}`
body := mocks.NewBody(jwt)
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
ba := NewBearerAuthorizer(spt)
req, err := Prepare(mocks.NewRequest(), ba.WithAuthorization())
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", spt.AccessToken) {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to set Authorization header")
}
if !refreshed {
t.Fatal("azure: BearerAuthorizer#WithAuthorization must refresh the token")
}
}
func TestServicePrincipalTokenWithAuthorizationReturnsErrorIfConnotRefresh(t *testing.T) {
oauthConfig, err := adal.NewOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
spt, err := adal.NewServicePrincipalToken(*oauthConfig, "id", "secret", "resource", nil)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
s := mocks.NewSender()
s.AppendResponse(mocks.NewResponseWithStatus("400 Bad Request", http.StatusBadRequest))
spt.SetSender(s)
ba := NewBearerAuthorizer(spt)
_, err = Prepare(mocks.NewRequest(), ba.WithAuthorization())
if err == nil {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to return an error when refresh fails")
}
}
func TestBearerAuthorizerCallback(t *testing.T) {
tenantString := "123-tenantID-456"
resourceString := "https://fake.resource.net"
s := mocks.NewSender()
resp := mocks.NewResponseWithStatus("401 Unauthorized", http.StatusUnauthorized)
mocks.SetResponseHeader(resp, bearerChallengeHeader, bearer+" \"authorization\"=\"https://fake.net/"+tenantString+"\",\"resource\"=\""+resourceString+"\"")
s.AppendResponse(resp)
auth := NewBearerAuthorizerCallback(s, func(tenantID, resource string) (*BearerAuthorizer, error) {
if tenantID != tenantString {
t.Fatal("BearerAuthorizerCallback: bad tenant ID")
}
if resource != resourceString {
t.Fatal("BearerAuthorizerCallback: bad resource")
}
oauthConfig, err := adal.NewOAuthConfig(TestActiveDirectoryEndpoint, tenantID)
if err != nil {
t.Fatalf("azure: NewOAuthConfig returned an error (%v)", err)
}
spt, err := adal.NewServicePrincipalToken(*oauthConfig, "id", "secret", resource)
if err != nil {
t.Fatalf("azure: NewServicePrincipalToken returned an error (%v)", err)
}
spt.SetSender(s)
return NewBearerAuthorizer(spt), nil
})
_, err := Prepare(mocks.NewRequest(), auth.WithAuthorization())
if err == nil {
t.Fatal("azure: BearerAuthorizerCallback#WithAuthorization failed to return an error when refresh fails")
}
}

View file

@ -1,140 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"net/http"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
func TestResponseHasStatusCode(t *testing.T) {
codes := []int{http.StatusOK, http.StatusAccepted}
resp := &http.Response{StatusCode: http.StatusAccepted}
if !ResponseHasStatusCode(resp, codes...) {
t.Fatalf("autorest: ResponseHasStatusCode failed to find %v in %v", resp.StatusCode, codes)
}
}
func TestResponseHasStatusCodeNotPresent(t *testing.T) {
codes := []int{http.StatusOK, http.StatusAccepted}
resp := &http.Response{StatusCode: http.StatusInternalServerError}
if ResponseHasStatusCode(resp, codes...) {
t.Fatalf("autorest: ResponseHasStatusCode unexpectedly found %v in %v", resp.StatusCode, codes)
}
}
func TestNewPollingRequestDoesNotReturnARequestWhenLocationHeaderIsMissing(t *testing.T) {
resp := mocks.NewResponseWithStatus("500 InternalServerError", http.StatusInternalServerError)
req, _ := NewPollingRequest(resp, nil)
if req != nil {
t.Fatal("autorest: NewPollingRequest returned an http.Request when the Location header was missing")
}
}
func TestNewPollingRequestReturnsAnErrorWhenPrepareFails(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
resp.Header.Set(http.CanonicalHeaderKey(HeaderLocation), mocks.TestBadURL)
_, err := NewPollingRequest(resp, nil)
if err == nil {
t.Fatal("autorest: NewPollingRequest failed to return an error when Prepare fails")
}
}
func TestNewPollingRequestDoesNotReturnARequestWhenPrepareFails(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
resp.Header.Set(http.CanonicalHeaderKey(HeaderLocation), mocks.TestBadURL)
req, _ := NewPollingRequest(resp, nil)
if req != nil {
t.Fatal("autorest: NewPollingRequest returned an http.Request when Prepare failed")
}
}
func TestNewPollingRequestReturnsAGetRequest(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
req, _ := NewPollingRequest(resp, nil)
if req.Method != "GET" {
t.Fatalf("autorest: NewPollingRequest did not create an HTTP GET request -- actual method %v", req.Method)
}
}
func TestNewPollingRequestProvidesTheURL(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
req, _ := NewPollingRequest(resp, nil)
if req.URL.String() != mocks.TestURL {
t.Fatalf("autorest: NewPollingRequest did not create an HTTP with the expected URL -- received %v, expected %v", req.URL, mocks.TestURL)
}
}
func TestGetLocation(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
l := GetLocation(resp)
if len(l) == 0 {
t.Fatalf("autorest: GetLocation failed to return Location header -- expected %v, received %v", mocks.TestURL, l)
}
}
func TestGetLocationReturnsEmptyStringForMissingLocation(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
l := GetLocation(resp)
if len(l) != 0 {
t.Fatalf("autorest: GetLocation return a value without a Location header -- received %v", l)
}
}
func TestGetRetryAfter(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
d := GetRetryAfter(resp, DefaultPollingDelay)
if d != mocks.TestDelay {
t.Fatalf("autorest: GetRetryAfter failed to returned the expected delay -- expected %v, received %v", mocks.TestDelay, d)
}
}
func TestGetRetryAfterReturnsDefaultDelayIfRetryHeaderIsMissing(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
d := GetRetryAfter(resp, DefaultPollingDelay)
if d != DefaultPollingDelay {
t.Fatalf("autorest: GetRetryAfter failed to returned the default delay for a missing Retry-After header -- expected %v, received %v",
DefaultPollingDelay, d)
}
}
func TestGetRetryAfterReturnsDefaultDelayIfRetryHeaderIsMalformed(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
resp.Header.Set(http.CanonicalHeaderKey(HeaderRetryAfter), "a very bad non-integer value")
d := GetRetryAfter(resp, DefaultPollingDelay)
if d != DefaultPollingDelay {
t.Fatalf("autorest: GetRetryAfter failed to returned the default delay for a malformed Retry-After header -- expected %v, received %v",
DefaultPollingDelay, d)
}
}

View file

@ -39,7 +39,7 @@ const (
operationSucceeded string = "Succeeded"
)
var pollingCodes = [...]int{http.StatusAccepted, http.StatusCreated, http.StatusOK}
var pollingCodes = [...]int{http.StatusNoContent, http.StatusAccepted, http.StatusCreated, http.StatusOK}
// Future provides a mechanism to access the status and results of an asynchronous request.
// Since futures are stateful they should be passed by value to avoid race conditions.
@ -234,20 +234,15 @@ func getAsyncOperation(resp *http.Response) string {
}
func hasSucceeded(state string) bool {
return state == operationSucceeded
return strings.EqualFold(state, operationSucceeded)
}
func hasTerminated(state string) bool {
switch state {
case operationCanceled, operationFailed, operationSucceeded:
return true
default:
return false
}
return strings.EqualFold(state, operationCanceled) || strings.EqualFold(state, operationFailed) || strings.EqualFold(state, operationSucceeded)
}
func hasFailed(state string) bool {
return state == operationFailed
return strings.EqualFold(state, operationFailed)
}
type provisioningTracker interface {
@ -426,7 +421,7 @@ func updatePollingState(resp *http.Response, ps *pollingState) error {
}
}
if ps.State == operationInProgress && ps.URI == "" {
if strings.EqualFold(ps.State, operationInProgress) && ps.URI == "" {
return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Unable to obtain polling URI for %s %s", resp.Request.Method, resp.Request.URL)
}
@ -463,3 +458,21 @@ func newPollingRequest(ps pollingState) (*http.Request, error) {
return reqPoll, nil
}
// AsyncOpIncompleteError is the type that's returned from a future that has not completed.
type AsyncOpIncompleteError struct {
// FutureType is the name of the type composed of a azure.Future.
FutureType string
}
// Error returns an error message including the originating type name of the error.
func (e AsyncOpIncompleteError) Error() string {
return fmt.Sprintf("%s: asynchronous operation has not completed", e.FutureType)
}
// NewAsyncOpIncompleteError creates a new AsyncOpIncompleteError with the specified parameters.
func NewAsyncOpIncompleteError(futureType string) AsyncOpIncompleteError {
return AsyncOpIncompleteError{
FutureType: futureType,
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,513 +0,0 @@
package azure
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"testing"
"time"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
headerAuthorization = "Authorization"
longDelay = 5 * time.Second
retryDelay = 10 * time.Millisecond
testLogPrefix = "azure:"
)
// Use a Client Inspector to set the request identifier.
func ExampleWithClientID() {
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
req, _ := autorest.Prepare(&http.Request{},
autorest.AsGet(),
autorest.WithBaseURL("https://microsoft.com/a/b/c/"))
c := autorest.Client{Sender: mocks.NewSender()}
c.RequestInspector = WithReturningClientID(uuid)
autorest.SendWithSender(c, req)
fmt.Printf("Inspector added the %s header with the value %s\n",
HeaderClientID, req.Header.Get(HeaderClientID))
fmt.Printf("Inspector added the %s header with the value %s\n",
HeaderReturnClientID, req.Header.Get(HeaderReturnClientID))
// Output:
// Inspector added the x-ms-client-request-id header with the value 71FDB9F4-5E49-4C12-B266-DE7B4FD999A6
// Inspector added the x-ms-return-client-request-id header with the value true
}
func TestWithReturningClientIDReturnsError(t *testing.T) {
var errIn error
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
_, errOut := autorest.Prepare(&http.Request{},
withErrorPrepareDecorator(&errIn),
WithReturningClientID(uuid))
if errOut == nil || errIn != errOut {
t.Fatalf("azure: WithReturningClientID failed to exit early when receiving an error -- expected (%v), received (%v)",
errIn, errOut)
}
}
func TestWithClientID(t *testing.T) {
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
req, _ := autorest.Prepare(&http.Request{},
WithClientID(uuid))
if req.Header.Get(HeaderClientID) != uuid {
t.Fatalf("azure: WithClientID failed to set %s -- expected %s, received %s",
HeaderClientID, uuid, req.Header.Get(HeaderClientID))
}
}
func TestWithReturnClientID(t *testing.T) {
b := false
req, _ := autorest.Prepare(&http.Request{},
WithReturnClientID(b))
if req.Header.Get(HeaderReturnClientID) != strconv.FormatBool(b) {
t.Fatalf("azure: WithReturnClientID failed to set %s -- expected %s, received %s",
HeaderClientID, strconv.FormatBool(b), req.Header.Get(HeaderClientID))
}
}
func TestExtractClientID(t *testing.T) {
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
resp := mocks.NewResponse()
mocks.SetResponseHeader(resp, HeaderClientID, uuid)
if ExtractClientID(resp) != uuid {
t.Fatalf("azure: ExtractClientID failed to extract the %s -- expected %s, received %s",
HeaderClientID, uuid, ExtractClientID(resp))
}
}
func TestExtractRequestID(t *testing.T) {
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
resp := mocks.NewResponse()
mocks.SetResponseHeader(resp, HeaderRequestID, uuid)
if ExtractRequestID(resp) != uuid {
t.Fatalf("azure: ExtractRequestID failed to extract the %s -- expected %s, received %s",
HeaderRequestID, uuid, ExtractRequestID(resp))
}
}
func TestIsAzureError_ReturnsTrueForAzureError(t *testing.T) {
if !IsAzureError(&RequestError{}) {
t.Fatalf("azure: IsAzureError failed to return true for an Azure Service error")
}
}
func TestIsAzureError_ReturnsFalseForNonAzureError(t *testing.T) {
if IsAzureError(fmt.Errorf("An Error")) {
t.Fatalf("azure: IsAzureError return true for an non-Azure Service error")
}
}
func TestNewErrorWithError_UsesReponseStatusCode(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("Error"), "packageType", "method", mocks.NewResponseWithStatus("Forbidden", http.StatusForbidden), "message")
if e.StatusCode != http.StatusForbidden {
t.Fatalf("azure: NewErrorWithError failed to use the Status Code of the passed Response -- expected %v, received %v", http.StatusForbidden, e.StatusCode)
}
}
func TestNewErrorWithError_ReturnsUnwrappedError(t *testing.T) {
e1 := RequestError{}
e1.ServiceError = &ServiceError{Code: "42", Message: "A Message"}
e1.StatusCode = 200
e1.RequestID = "A RequestID"
e2 := NewErrorWithError(&e1, "packageType", "method", nil, "message")
if !reflect.DeepEqual(e1, e2) {
t.Fatalf("azure: NewErrorWithError wrapped an RequestError -- expected %T, received %T", e1, e2)
}
}
func TestNewErrorWithError_WrapsAnError(t *testing.T) {
e1 := fmt.Errorf("Inner Error")
var e2 interface{} = NewErrorWithError(e1, "packageType", "method", nil, "message")
if _, ok := e2.(RequestError); !ok {
t.Fatalf("azure: NewErrorWithError failed to wrap a standard error -- received %T", e2)
}
}
func TestWithErrorUnlessStatusCode_NotAnAzureError(t *testing.T) {
body := `<html>
<head>
<title>IIS Error page</title>
</head>
<body>Some non-JSON error page</body>
</html>`
r := mocks.NewResponseWithContent(body)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusBadRequest
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
ok, _ := err.(*RequestError)
if ok != nil {
t.Fatalf("azure: azure.RequestError returned from malformed response: %v", err)
}
// the error body should still be there
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != body {
t.Fatalf("response body is wrong. got=%q exptected=%q", string(b), body)
}
}
func TestWithErrorUnlessStatusCode_FoundAzureErrorWithoutDetails(t *testing.T) {
j := `{
"error": {
"code": "InternalError",
"message": "Azure is having trouble right now."
}
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatalf("azure: returned nil error for proper error response")
}
azErr, ok := err.(*RequestError)
if !ok {
t.Fatalf("azure: returned error is not azure.RequestError: %T", err)
}
expected := "autorest/azure: Service returned an error. Status=500 Code=\"InternalError\" Message=\"Azure is having trouble right now.\""
if !reflect.DeepEqual(expected, azErr.Error()) {
t.Fatalf("azure: service error is not unmarshaled properly.\nexpected=%v\ngot=%v", expected, azErr.Error())
}
if expected := http.StatusInternalServerError; azErr.StatusCode != expected {
t.Fatalf("azure: got wrong StatusCode=%d Expected=%d", azErr.StatusCode, expected)
}
if expected := uuid; azErr.RequestID != expected {
t.Fatalf("azure: wrong request ID in error. expected=%q; got=%q", expected, azErr.RequestID)
}
_ = azErr.Error()
// the error body should still be there
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != j {
t.Fatalf("response body is wrong. got=%q expected=%q", string(b), j)
}
}
func TestWithErrorUnlessStatusCode_FoundAzureErrorWithDetails(t *testing.T) {
j := `{
"error": {
"code": "InternalError",
"message": "Azure is having trouble right now.",
"details": [{"code": "conflict1", "message":"error message1"},
{"code": "conflict2", "message":"error message2"}]
}
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatalf("azure: returned nil error for proper error response")
}
azErr, ok := err.(*RequestError)
if !ok {
t.Fatalf("azure: returned error is not azure.RequestError: %T", err)
}
if expected := "InternalError"; azErr.ServiceError.Code != expected {
t.Fatalf("azure: wrong error code. expected=%q; got=%q", expected, azErr.ServiceError.Code)
}
if azErr.ServiceError.Message == "" {
t.Fatalf("azure: error message is not unmarshaled properly")
}
b, _ := json.Marshal(*azErr.ServiceError.Details)
if string(b) != `[{"code":"conflict1","message":"error message1"},{"code":"conflict2","message":"error message2"}]` {
t.Fatalf("azure: error details is not unmarshaled properly")
}
if expected := http.StatusInternalServerError; azErr.StatusCode != expected {
t.Fatalf("azure: got wrong StatusCode=%v Expected=%d", azErr.StatusCode, expected)
}
if expected := uuid; azErr.RequestID != expected {
t.Fatalf("azure: wrong request ID in error. expected=%q; got=%q", expected, azErr.RequestID)
}
_ = azErr.Error()
// the error body should still be there
defer r.Body.Close()
b, err = ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != j {
t.Fatalf("response body is wrong. got=%q expected=%q", string(b), j)
}
}
func TestWithErrorUnlessStatusCode_NoAzureError(t *testing.T) {
j := `{
"Status":"NotFound"
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatalf("azure: returned nil error for proper error response")
}
azErr, ok := err.(*RequestError)
if !ok {
t.Fatalf("azure: returned error is not azure.RequestError: %T", err)
}
expected := &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
}
if !reflect.DeepEqual(expected, azErr.ServiceError) {
t.Fatalf("azure: service error is not unmarshaled properly. expected=%q\ngot=%q", expected, azErr.ServiceError)
}
if expected := http.StatusInternalServerError; azErr.StatusCode != expected {
t.Fatalf("azure: got wrong StatusCode=%v Expected=%d", azErr.StatusCode, expected)
}
if expected := uuid; azErr.RequestID != expected {
t.Fatalf("azure: wrong request ID in error. expected=%q; got=%q", expected, azErr.RequestID)
}
_ = azErr.Error()
// the error body should still be there
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != j {
t.Fatalf("response body is wrong. got=%q expected=%q", string(b), j)
}
}
func TestWithErrorUnlessStatusCode_UnwrappedError(t *testing.T) {
j := `{
"target": null,
"code": "InternalError",
"message": "Azure is having trouble right now.",
"details": [{"code": "conflict1", "message":"error message1"},
{"code": "conflict2", "message":"error message2"}],
"innererror": []
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatal("azure: returned nil error for proper error response")
}
azErr, ok := err.(*RequestError)
if !ok {
t.Fatalf("returned error is not azure.RequestError: %T", err)
}
if expected := http.StatusInternalServerError; azErr.StatusCode != expected {
t.Logf("Incorrect StatusCode got: %v want: %d", azErr.StatusCode, expected)
t.Fail()
}
if expected := "Azure is having trouble right now."; azErr.ServiceError.Message != expected {
t.Logf("Incorrect Message\n\tgot: %q\n\twant: %q", azErr.Message, expected)
t.Fail()
}
if expected := uuid; azErr.RequestID != expected {
t.Logf("Incorrect request ID\n\tgot: %q\n\twant: %q", azErr.RequestID, expected)
t.Fail()
}
expectedServiceErrorDetails := `[{"code":"conflict1","message":"error message1"},{"code":"conflict2","message":"error message2"}]`
if azErr.ServiceError == nil {
t.Logf("`ServiceError` was nil when it shouldn't have been.")
t.Fail()
} else if azErr.ServiceError.Details == nil {
t.Logf("`ServiceError.Details` was nil when it should have been %q", expectedServiceErrorDetails)
t.Fail()
} else if details, _ := json.Marshal(*azErr.ServiceError.Details); expectedServiceErrorDetails != string(details) {
t.Logf("Error detaisl was not unmarshaled properly.\n\tgot: %q\n\twant: %q", string(details), expectedServiceErrorDetails)
t.Fail()
}
// the error body should still be there
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Error(err)
}
if string(b) != j {
t.Fatalf("response body is wrong. got=%q expected=%q", string(b), j)
}
}
func TestRequestErrorString_WithError(t *testing.T) {
j := `{
"error": {
"code": "InternalError",
"message": "Conflict",
"details": [{"code": "conflict1", "message":"error message1"}]
}
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatalf("azure: returned nil error for proper error response")
}
azErr, _ := err.(*RequestError)
expected := "autorest/azure: Service returned an error. Status=500 Code=\"InternalError\" Message=\"Conflict\" Details=[{\"code\":\"conflict1\",\"message\":\"error message1\"}]"
if expected != azErr.Error() {
t.Fatalf("azure: send wrong RequestError.\nexpected=%v\ngot=%v", expected, azErr.Error())
}
}
func withErrorPrepareDecorator(e *error) autorest.PrepareDecorator {
return func(p autorest.Preparer) autorest.Preparer {
return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) {
*e = fmt.Errorf("azure: Faux Prepare Error")
return r, *e
})
}
}
func withAsyncResponseDecorator(n int) autorest.SendDecorator {
i := 0
return func(s autorest.Sender) autorest.Sender {
return autorest.SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err == nil {
if i < n {
resp.StatusCode = http.StatusCreated
resp.Header = http.Header{}
resp.Header.Add(http.CanonicalHeaderKey(headerAsyncOperation), mocks.TestURL)
i++
} else {
resp.StatusCode = http.StatusOK
resp.Header.Del(http.CanonicalHeaderKey(headerAsyncOperation))
}
}
return resp, err
})
}
}
type mockAuthorizer struct{}
func (ma mockAuthorizer) WithAuthorization() autorest.PrepareDecorator {
return autorest.WithHeader(headerAuthorization, mocks.TestAuthorizationHeader)
}
type mockFailingAuthorizer struct{}
func (mfa mockFailingAuthorizer) WithAuthorization() autorest.PrepareDecorator {
return func(p autorest.Preparer) autorest.Preparer {
return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) {
return r, fmt.Errorf("ERROR: mockFailingAuthorizer returned expected error")
})
}
}
type mockInspector struct {
wasInvoked bool
}
func (mi *mockInspector) WithInspection() autorest.PrepareDecorator {
return func(p autorest.Preparer) autorest.Preparer {
return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) {
mi.wasInvoked = true
return p.Prepare(r)
})
}
}
func (mi *mockInspector) ByInspecting() autorest.RespondDecorator {
return func(r autorest.Responder) autorest.Responder {
return autorest.ResponderFunc(func(resp *http.Response) error {
mi.wasInvoked = true
return r.Respond(resp)
})
}
}

View file

@ -83,10 +83,10 @@ var (
PublishSettingsURL: "https://manage.windowsazure.us/publishsettings/index",
ServiceManagementEndpoint: "https://management.core.usgovcloudapi.net/",
ResourceManagerEndpoint: "https://management.usgovcloudapi.net/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.us/",
GalleryEndpoint: "https://gallery.usgovcloudapi.net/",
KeyVaultEndpoint: "https://vault.usgovcloudapi.net/",
GraphEndpoint: "https://graph.usgovcloudapi.net/",
GraphEndpoint: "https://graph.windows.net/",
StorageEndpointSuffix: "core.usgovcloudapi.net",
SQLDatabaseDNSSuffix: "database.usgovcloudapi.net",
TrafficManagerDNSSuffix: "usgovtrafficmanager.net",

View file

@ -1,284 +0,0 @@
// test
package azure
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"os"
"path"
"path/filepath"
"runtime"
"testing"
)
// This correlates to the expected contents of ./testdata/test_environment_1.json
var testEnvironment1 = Environment{
Name: "--unit-test--",
ManagementPortalURL: "--management-portal-url",
PublishSettingsURL: "--publish-settings-url--",
ServiceManagementEndpoint: "--service-management-endpoint--",
ResourceManagerEndpoint: "--resource-management-endpoint--",
ActiveDirectoryEndpoint: "--active-directory-endpoint--",
GalleryEndpoint: "--gallery-endpoint--",
KeyVaultEndpoint: "--key-vault--endpoint--",
GraphEndpoint: "--graph-endpoint--",
StorageEndpointSuffix: "--storage-endpoint-suffix--",
SQLDatabaseDNSSuffix: "--sql-database-dns-suffix--",
TrafficManagerDNSSuffix: "--traffic-manager-dns-suffix--",
KeyVaultDNSSuffix: "--key-vault-dns-suffix--",
ServiceBusEndpointSuffix: "--service-bus-endpoint-suffix--",
ServiceManagementVMDNSSuffix: "--asm-vm-dns-suffix--",
ResourceManagerVMDNSSuffix: "--arm-vm-dns-suffix--",
ContainerRegistryDNSSuffix: "--container-registry-dns-suffix--",
}
func TestEnvironment_EnvironmentFromFile(t *testing.T) {
got, err := EnvironmentFromFile(filepath.Join("testdata", "test_environment_1.json"))
if err != nil {
t.Error(err)
}
if got != testEnvironment1 {
t.Logf("got: %v want: %v", got, testEnvironment1)
t.Fail()
}
}
func TestEnvironment_EnvironmentFromName_Stack(t *testing.T) {
_, currentFile, _, _ := runtime.Caller(0)
prevEnvFilepathValue := os.Getenv(EnvironmentFilepathName)
os.Setenv(EnvironmentFilepathName, filepath.Join(path.Dir(currentFile), "testdata", "test_environment_1.json"))
defer os.Setenv(EnvironmentFilepathName, prevEnvFilepathValue)
got, err := EnvironmentFromName("AZURESTACKCLOUD")
if err != nil {
t.Error(err)
}
if got != testEnvironment1 {
t.Logf("got: %v want: %v", got, testEnvironment1)
t.Fail()
}
}
func TestEnvironmentFromName(t *testing.T) {
name := "azurechinacloud"
if env, _ := EnvironmentFromName(name); env != ChinaCloud {
t.Errorf("Expected to get ChinaCloud for %q", name)
}
name = "AzureChinaCloud"
if env, _ := EnvironmentFromName(name); env != ChinaCloud {
t.Errorf("Expected to get ChinaCloud for %q", name)
}
name = "azuregermancloud"
if env, _ := EnvironmentFromName(name); env != GermanCloud {
t.Errorf("Expected to get GermanCloud for %q", name)
}
name = "AzureGermanCloud"
if env, _ := EnvironmentFromName(name); env != GermanCloud {
t.Errorf("Expected to get GermanCloud for %q", name)
}
name = "azurepubliccloud"
if env, _ := EnvironmentFromName(name); env != PublicCloud {
t.Errorf("Expected to get PublicCloud for %q", name)
}
name = "AzurePublicCloud"
if env, _ := EnvironmentFromName(name); env != PublicCloud {
t.Errorf("Expected to get PublicCloud for %q", name)
}
name = "azureusgovernmentcloud"
if env, _ := EnvironmentFromName(name); env != USGovernmentCloud {
t.Errorf("Expected to get USGovernmentCloud for %q", name)
}
name = "AzureUSGovernmentCloud"
if env, _ := EnvironmentFromName(name); env != USGovernmentCloud {
t.Errorf("Expected to get USGovernmentCloud for %q", name)
}
name = "thisisnotarealcloudenv"
if _, err := EnvironmentFromName(name); err == nil {
t.Errorf("Expected to get an error for %q", name)
}
}
func TestDeserializeEnvironment(t *testing.T) {
env := `{
"name": "--name--",
"ActiveDirectoryEndpoint": "--active-directory-endpoint--",
"galleryEndpoint": "--gallery-endpoint--",
"graphEndpoint": "--graph-endpoint--",
"keyVaultDNSSuffix": "--key-vault-dns-suffix--",
"keyVaultEndpoint": "--key-vault-endpoint--",
"managementPortalURL": "--management-portal-url--",
"publishSettingsURL": "--publish-settings-url--",
"resourceManagerEndpoint": "--resource-manager-endpoint--",
"serviceBusEndpointSuffix": "--service-bus-endpoint-suffix--",
"serviceManagementEndpoint": "--service-management-endpoint--",
"sqlDatabaseDNSSuffix": "--sql-database-dns-suffix--",
"storageEndpointSuffix": "--storage-endpoint-suffix--",
"trafficManagerDNSSuffix": "--traffic-manager-dns-suffix--",
"serviceManagementVMDNSSuffix": "--asm-vm-dns-suffix--",
"resourceManagerVMDNSSuffix": "--arm-vm-dns-suffix--",
"containerRegistryDNSSuffix": "--container-registry-dns-suffix--"
}`
testSubject := Environment{}
err := json.Unmarshal([]byte(env), &testSubject)
if err != nil {
t.Fatalf("failed to unmarshal: %s", err)
}
if "--name--" != testSubject.Name {
t.Errorf("Expected Name to be \"--name--\", but got %q", testSubject.Name)
}
if "--management-portal-url--" != testSubject.ManagementPortalURL {
t.Errorf("Expected ManagementPortalURL to be \"--management-portal-url--\", but got %q", testSubject.ManagementPortalURL)
}
if "--publish-settings-url--" != testSubject.PublishSettingsURL {
t.Errorf("Expected PublishSettingsURL to be \"--publish-settings-url--\", but got %q", testSubject.PublishSettingsURL)
}
if "--service-management-endpoint--" != testSubject.ServiceManagementEndpoint {
t.Errorf("Expected ServiceManagementEndpoint to be \"--service-management-endpoint--\", but got %q", testSubject.ServiceManagementEndpoint)
}
if "--resource-manager-endpoint--" != testSubject.ResourceManagerEndpoint {
t.Errorf("Expected ResourceManagerEndpoint to be \"--resource-manager-endpoint--\", but got %q", testSubject.ResourceManagerEndpoint)
}
if "--active-directory-endpoint--" != testSubject.ActiveDirectoryEndpoint {
t.Errorf("Expected ActiveDirectoryEndpoint to be \"--active-directory-endpoint--\", but got %q", testSubject.ActiveDirectoryEndpoint)
}
if "--gallery-endpoint--" != testSubject.GalleryEndpoint {
t.Errorf("Expected GalleryEndpoint to be \"--gallery-endpoint--\", but got %q", testSubject.GalleryEndpoint)
}
if "--key-vault-endpoint--" != testSubject.KeyVaultEndpoint {
t.Errorf("Expected KeyVaultEndpoint to be \"--key-vault-endpoint--\", but got %q", testSubject.KeyVaultEndpoint)
}
if "--graph-endpoint--" != testSubject.GraphEndpoint {
t.Errorf("Expected GraphEndpoint to be \"--graph-endpoint--\", but got %q", testSubject.GraphEndpoint)
}
if "--storage-endpoint-suffix--" != testSubject.StorageEndpointSuffix {
t.Errorf("Expected StorageEndpointSuffix to be \"--storage-endpoint-suffix--\", but got %q", testSubject.StorageEndpointSuffix)
}
if "--sql-database-dns-suffix--" != testSubject.SQLDatabaseDNSSuffix {
t.Errorf("Expected sql-database-dns-suffix to be \"--sql-database-dns-suffix--\", but got %q", testSubject.SQLDatabaseDNSSuffix)
}
if "--key-vault-dns-suffix--" != testSubject.KeyVaultDNSSuffix {
t.Errorf("Expected StorageEndpointSuffix to be \"--key-vault-dns-suffix--\", but got %q", testSubject.KeyVaultDNSSuffix)
}
if "--service-bus-endpoint-suffix--" != testSubject.ServiceBusEndpointSuffix {
t.Errorf("Expected StorageEndpointSuffix to be \"--service-bus-endpoint-suffix--\", but got %q", testSubject.ServiceBusEndpointSuffix)
}
if "--asm-vm-dns-suffix--" != testSubject.ServiceManagementVMDNSSuffix {
t.Errorf("Expected ServiceManagementVMDNSSuffix to be \"--asm-vm-dns-suffix--\", but got %q", testSubject.ServiceManagementVMDNSSuffix)
}
if "--arm-vm-dns-suffix--" != testSubject.ResourceManagerVMDNSSuffix {
t.Errorf("Expected ResourceManagerVMDNSSuffix to be \"--arm-vm-dns-suffix--\", but got %q", testSubject.ResourceManagerVMDNSSuffix)
}
if "--container-registry-dns-suffix--" != testSubject.ContainerRegistryDNSSuffix {
t.Errorf("Expected ContainerRegistryDNSSuffix to be \"--container-registry-dns-suffix--\", but got %q", testSubject.ContainerRegistryDNSSuffix)
}
}
func TestRoundTripSerialization(t *testing.T) {
env := Environment{
Name: "--unit-test--",
ManagementPortalURL: "--management-portal-url",
PublishSettingsURL: "--publish-settings-url--",
ServiceManagementEndpoint: "--service-management-endpoint--",
ResourceManagerEndpoint: "--resource-management-endpoint--",
ActiveDirectoryEndpoint: "--active-directory-endpoint--",
GalleryEndpoint: "--gallery-endpoint--",
KeyVaultEndpoint: "--key-vault--endpoint--",
GraphEndpoint: "--graph-endpoint--",
StorageEndpointSuffix: "--storage-endpoint-suffix--",
SQLDatabaseDNSSuffix: "--sql-database-dns-suffix--",
TrafficManagerDNSSuffix: "--traffic-manager-dns-suffix--",
KeyVaultDNSSuffix: "--key-vault-dns-suffix--",
ServiceBusEndpointSuffix: "--service-bus-endpoint-suffix--",
ServiceManagementVMDNSSuffix: "--asm-vm-dns-suffix--",
ResourceManagerVMDNSSuffix: "--arm-vm-dns-suffix--",
ContainerRegistryDNSSuffix: "--container-registry-dns-suffix--",
}
bytes, err := json.Marshal(env)
if err != nil {
t.Fatalf("failed to marshal: %s", err)
}
testSubject := Environment{}
err = json.Unmarshal(bytes, &testSubject)
if err != nil {
t.Fatalf("failed to unmarshal: %s", err)
}
if env.Name != testSubject.Name {
t.Errorf("Expected Name to be %q, but got %q", env.Name, testSubject.Name)
}
if env.ManagementPortalURL != testSubject.ManagementPortalURL {
t.Errorf("Expected ManagementPortalURL to be %q, but got %q", env.ManagementPortalURL, testSubject.ManagementPortalURL)
}
if env.PublishSettingsURL != testSubject.PublishSettingsURL {
t.Errorf("Expected PublishSettingsURL to be %q, but got %q", env.PublishSettingsURL, testSubject.PublishSettingsURL)
}
if env.ServiceManagementEndpoint != testSubject.ServiceManagementEndpoint {
t.Errorf("Expected ServiceManagementEndpoint to be %q, but got %q", env.ServiceManagementEndpoint, testSubject.ServiceManagementEndpoint)
}
if env.ResourceManagerEndpoint != testSubject.ResourceManagerEndpoint {
t.Errorf("Expected ResourceManagerEndpoint to be %q, but got %q", env.ResourceManagerEndpoint, testSubject.ResourceManagerEndpoint)
}
if env.ActiveDirectoryEndpoint != testSubject.ActiveDirectoryEndpoint {
t.Errorf("Expected ActiveDirectoryEndpoint to be %q, but got %q", env.ActiveDirectoryEndpoint, testSubject.ActiveDirectoryEndpoint)
}
if env.GalleryEndpoint != testSubject.GalleryEndpoint {
t.Errorf("Expected GalleryEndpoint to be %q, but got %q", env.GalleryEndpoint, testSubject.GalleryEndpoint)
}
if env.KeyVaultEndpoint != testSubject.KeyVaultEndpoint {
t.Errorf("Expected KeyVaultEndpoint to be %q, but got %q", env.KeyVaultEndpoint, testSubject.KeyVaultEndpoint)
}
if env.GraphEndpoint != testSubject.GraphEndpoint {
t.Errorf("Expected GraphEndpoint to be %q, but got %q", env.GraphEndpoint, testSubject.GraphEndpoint)
}
if env.StorageEndpointSuffix != testSubject.StorageEndpointSuffix {
t.Errorf("Expected StorageEndpointSuffix to be %q, but got %q", env.StorageEndpointSuffix, testSubject.StorageEndpointSuffix)
}
if env.SQLDatabaseDNSSuffix != testSubject.SQLDatabaseDNSSuffix {
t.Errorf("Expected SQLDatabaseDNSSuffix to be %q, but got %q", env.SQLDatabaseDNSSuffix, testSubject.SQLDatabaseDNSSuffix)
}
if env.TrafficManagerDNSSuffix != testSubject.TrafficManagerDNSSuffix {
t.Errorf("Expected TrafficManagerDNSSuffix to be %q, but got %q", env.TrafficManagerDNSSuffix, testSubject.TrafficManagerDNSSuffix)
}
if env.KeyVaultDNSSuffix != testSubject.KeyVaultDNSSuffix {
t.Errorf("Expected KeyVaultDNSSuffix to be %q, but got %q", env.KeyVaultDNSSuffix, testSubject.KeyVaultDNSSuffix)
}
if env.ServiceBusEndpointSuffix != testSubject.ServiceBusEndpointSuffix {
t.Errorf("Expected ServiceBusEndpointSuffix to be %q, but got %q", env.ServiceBusEndpointSuffix, testSubject.ServiceBusEndpointSuffix)
}
if env.ServiceManagementVMDNSSuffix != testSubject.ServiceManagementVMDNSSuffix {
t.Errorf("Expected ServiceManagementVMDNSSuffix to be %q, but got %q", env.ServiceManagementVMDNSSuffix, testSubject.ServiceManagementVMDNSSuffix)
}
if env.ResourceManagerVMDNSSuffix != testSubject.ResourceManagerVMDNSSuffix {
t.Errorf("Expected ResourceManagerVMDNSSuffix to be %q, but got %q", env.ResourceManagerVMDNSSuffix, testSubject.ResourceManagerVMDNSSuffix)
}
if env.ContainerRegistryDNSSuffix != testSubject.ContainerRegistryDNSSuffix {
t.Errorf("Expected ContainerRegistryDNSSuffix to be %q, but got %q", env.ContainerRegistryDNSSuffix, testSubject.ContainerRegistryDNSSuffix)
}
}

View file

@ -44,7 +44,7 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
return resp, err
}
if resp.StatusCode != http.StatusConflict {
if resp.StatusCode != http.StatusConflict || client.SkipResourceProviderRegistration {
return resp, err
}
var re RequestError
@ -159,7 +159,7 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError
}
req.Cancel = originalReq.Cancel
resp, err := autorest.SendWithSender(client.Sender, req,
resp, err := autorest.SendWithSender(client, req,
autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
)
if err != nil {

View file

@ -1,81 +0,0 @@
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package azure
import (
"net/http"
"testing"
"time"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/mocks"
)
func TestDoRetryWithRegistration(t *testing.T) {
client := mocks.NewSender()
// first response, should retry because it is a transient error
client.AppendResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError))
// response indicates the resource provider has not been registered
client.AppendResponse(mocks.NewResponseWithBodyAndStatus(mocks.NewBody(`{
"error":{
"code":"MissingSubscriptionRegistration",
"message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.EventGrid'. See https://aka.ms/rps-not-found for how to register subscriptions.",
"details":[
{
"code":"MissingSubscriptionRegistration",
"target":"Microsoft.EventGrid",
"message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.EventGrid'. See https://aka.ms/rps-not-found for how to register subscriptions."
}
]
}
}
`), http.StatusConflict, "MissingSubscriptionRegistration"))
// first poll response, still not ready
client.AppendResponse(mocks.NewResponseWithBodyAndStatus(mocks.NewBody(`{
"registrationState": "Registering"
}
`), http.StatusOK, "200 OK"))
// last poll response, respurce provider has been registered
client.AppendResponse(mocks.NewResponseWithBodyAndStatus(mocks.NewBody(`{
"registrationState": "Registered"
}
`), http.StatusOK, "200 OK"))
// retry original request, response is successful
client.AppendResponse(mocks.NewResponseWithStatus("200 OK", http.StatusOK))
req := mocks.NewRequestForURL("https://lol/subscriptions/rofl")
req.Body = mocks.NewBody("lolol")
r, err := autorest.SendWithSender(client, req,
DoRetryWithRegistration(autorest.Client{
PollingDelay: time.Second,
PollingDuration: time.Second * 10,
RetryAttempts: 5,
RetryDuration: time.Second,
Sender: client,
}),
)
if err != nil {
t.Fatalf("got error: %v", err)
}
autorest.Respond(r,
autorest.ByDiscardingBody(),
autorest.ByClosing(),
)
if r.StatusCode != http.StatusOK {
t.Fatalf("azure: Sender#DoRetryWithRegistration -- Got: StatusCode %v; Want: StatusCode 200 OK", r.StatusCode)
}
}

View file

@ -166,6 +166,9 @@ type Client struct {
UserAgent string
Jar http.CookieJar
// Set to true to skip attempted registration of resource providers (false by default).
SkipResourceProviderRegistration bool
}
// NewClientWithUserAgent returns an instance of a Client with the UserAgent set to the passed
@ -204,7 +207,13 @@ func (c Client) Do(r *http.Request) (*http.Response, error) {
c.WithInspection(),
c.WithAuthorization())
if err != nil {
return nil, NewErrorWithError(err, "autorest/Client", "Do", nil, "Preparing request failed")
var resp *http.Response
if detErr, ok := err.(DetailedError); ok {
// if the authorization failed (e.g. invalid credentials) there will
// be a response associated with the error, be sure to return it.
resp = detErr.Response
}
return resp, NewErrorWithError(err, "autorest/Client", "Do", nil, "Preparing request failed")
}
resp, err := SendWithSender(c.sender(), r)

View file

@ -1,402 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"math/rand"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/mocks"
)
func TestLoggingInspectorWithInspection(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.RequestInspector = li.WithInspection()
Prepare(mocks.NewRequestWithContent("Content"),
c.WithInspection())
if len(b.String()) <= 0 {
t.Fatal("autorest: LoggingInspector#WithInspection did not record Request to the log")
}
}
func TestLoggingInspectorWithInspectionEmitsErrors(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
r := mocks.NewRequestWithContent("Content")
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.RequestInspector = li.WithInspection()
if _, err := Prepare(r,
c.WithInspection()); err != nil {
t.Error(err)
}
if len(b.String()) <= 0 {
t.Fatal("autorest: LoggingInspector#WithInspection did not record Request to the log")
}
}
func TestLoggingInspectorWithInspectionRestoresBody(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
r := mocks.NewRequestWithContent("Content")
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.RequestInspector = li.WithInspection()
Prepare(r,
c.WithInspection())
s, _ := ioutil.ReadAll(r.Body)
if len(s) <= 0 {
t.Fatal("autorest: LoggingInspector#WithInspection did not restore the Request body")
}
}
func TestLoggingInspectorByInspecting(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.ResponseInspector = li.ByInspecting()
Respond(mocks.NewResponseWithContent("Content"),
c.ByInspecting())
if len(b.String()) <= 0 {
t.Fatal("autorest: LoggingInspector#ByInspection did not record Response to the log")
}
}
func TestLoggingInspectorByInspectingEmitsErrors(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
r := mocks.NewResponseWithContent("Content")
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.ResponseInspector = li.ByInspecting()
if err := Respond(r,
c.ByInspecting()); err != nil {
t.Fatal(err)
}
if len(b.String()) <= 0 {
t.Fatal("autorest: LoggingInspector#ByInspection did not record Response to the log")
}
}
func TestLoggingInspectorByInspectingRestoresBody(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
r := mocks.NewResponseWithContent("Content")
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.ResponseInspector = li.ByInspecting()
Respond(r,
c.ByInspecting())
s, _ := ioutil.ReadAll(r.Body)
if len(s) <= 0 {
t.Fatal("autorest: LoggingInspector#ByInspecting did not restore the Response body")
}
}
func TestNewClientWithUserAgent(t *testing.T) {
ua := "UserAgent"
c := NewClientWithUserAgent(ua)
completeUA := fmt.Sprintf("%s %s", defaultUserAgent, ua)
if c.UserAgent != completeUA {
t.Fatalf("autorest: NewClientWithUserAgent failed to set the UserAgent -- expected %s, received %s",
completeUA, c.UserAgent)
}
}
func TestAddToUserAgent(t *testing.T) {
ua := "UserAgent"
c := NewClientWithUserAgent(ua)
ext := "extension"
err := c.AddToUserAgent(ext)
if err != nil {
t.Fatalf("autorest: AddToUserAgent returned error -- expected nil, received %s", err)
}
completeUA := fmt.Sprintf("%s %s %s", defaultUserAgent, ua, ext)
if c.UserAgent != completeUA {
t.Fatalf("autorest: AddToUserAgent failed to add an extension to the UserAgent -- expected %s, received %s",
completeUA, c.UserAgent)
}
err = c.AddToUserAgent("")
if err == nil {
t.Fatalf("autorest: AddToUserAgent didn't return error -- expected %s, received nil",
fmt.Errorf("Extension was empty, User Agent stayed as %s", c.UserAgent))
}
if c.UserAgent != completeUA {
t.Fatalf("autorest: AddToUserAgent failed to not add an empty extension to the UserAgent -- expected %s, received %s",
completeUA, c.UserAgent)
}
}
func TestClientSenderReturnsHttpClientByDefault(t *testing.T) {
c := Client{}
if fmt.Sprintf("%T", c.sender()) != "*http.Client" {
t.Fatal("autorest: Client#sender failed to return http.Client by default")
}
}
func TestClientSenderReturnsSetSender(t *testing.T) {
c := Client{}
s := mocks.NewSender()
c.Sender = s
if c.sender() != s {
t.Fatal("autorest: Client#sender failed to return set Sender")
}
}
func TestClientDoInvokesSender(t *testing.T) {
c := Client{}
s := mocks.NewSender()
c.Sender = s
c.Do(&http.Request{})
if s.Attempts() != 1 {
t.Fatal("autorest: Client#Do failed to invoke the Sender")
}
}
func TestClientDoSetsUserAgent(t *testing.T) {
ua := "UserAgent"
c := Client{UserAgent: ua}
r := mocks.NewRequest()
s := mocks.NewSender()
c.Sender = s
c.Do(r)
if r.UserAgent() != ua {
t.Fatalf("autorest: Client#Do failed to correctly set User-Agent header: %s=%s",
http.CanonicalHeaderKey(headerUserAgent), r.UserAgent())
}
}
func TestClientDoSetsAuthorization(t *testing.T) {
r := mocks.NewRequest()
s := mocks.NewSender()
c := Client{Authorizer: mockAuthorizer{}, Sender: s}
c.Do(r)
if len(r.Header.Get(http.CanonicalHeaderKey(headerAuthorization))) <= 0 {
t.Fatalf("autorest: Client#Send failed to set Authorization header -- %s=%s",
http.CanonicalHeaderKey(headerAuthorization),
r.Header.Get(http.CanonicalHeaderKey(headerAuthorization)))
}
}
func TestClientDoInvokesRequestInspector(t *testing.T) {
r := mocks.NewRequest()
s := mocks.NewSender()
i := &mockInspector{}
c := Client{RequestInspector: i.WithInspection(), Sender: s}
c.Do(r)
if !i.wasInvoked {
t.Fatal("autorest: Client#Send failed to invoke the RequestInspector")
}
}
func TestClientDoInvokesResponseInspector(t *testing.T) {
r := mocks.NewRequest()
s := mocks.NewSender()
i := &mockInspector{}
c := Client{ResponseInspector: i.ByInspecting(), Sender: s}
c.Do(r)
if !i.wasInvoked {
t.Fatal("autorest: Client#Send failed to invoke the ResponseInspector")
}
}
func TestClientDoReturnsErrorIfPrepareFails(t *testing.T) {
c := Client{}
s := mocks.NewSender()
c.Authorizer = mockFailingAuthorizer{}
c.Sender = s
_, err := c.Do(&http.Request{})
if err == nil {
t.Fatalf("autorest: Client#Do failed to return an error when Prepare failed")
}
}
func TestClientDoDoesNotSendIfPrepareFails(t *testing.T) {
c := Client{}
s := mocks.NewSender()
c.Authorizer = mockFailingAuthorizer{}
c.Sender = s
c.Do(&http.Request{})
if s.Attempts() > 0 {
t.Fatal("autorest: Client#Do failed to invoke the Sender")
}
}
func TestClientAuthorizerReturnsNullAuthorizerByDefault(t *testing.T) {
c := Client{}
if fmt.Sprintf("%T", c.authorizer()) != "autorest.NullAuthorizer" {
t.Fatal("autorest: Client#authorizer failed to return the NullAuthorizer by default")
}
}
func TestClientAuthorizerReturnsSetAuthorizer(t *testing.T) {
c := Client{}
c.Authorizer = mockAuthorizer{}
if fmt.Sprintf("%T", c.authorizer()) != "autorest.mockAuthorizer" {
t.Fatal("autorest: Client#authorizer failed to return the set Authorizer")
}
}
func TestClientWithAuthorizer(t *testing.T) {
c := Client{}
c.Authorizer = mockAuthorizer{}
req, _ := Prepare(&http.Request{},
c.WithAuthorization())
if req.Header.Get(headerAuthorization) == "" {
t.Fatal("autorest: Client#WithAuthorizer failed to return the WithAuthorizer from the active Authorizer")
}
}
func TestClientWithInspection(t *testing.T) {
c := Client{}
r := &mockInspector{}
c.RequestInspector = r.WithInspection()
Prepare(&http.Request{},
c.WithInspection())
if !r.wasInvoked {
t.Fatal("autorest: Client#WithInspection failed to invoke RequestInspector")
}
}
func TestClientWithInspectionSetsDefault(t *testing.T) {
c := Client{}
r1 := &http.Request{}
r2, _ := Prepare(r1,
c.WithInspection())
if !reflect.DeepEqual(r1, r2) {
t.Fatal("autorest: Client#WithInspection failed to provide a default RequestInspector")
}
}
func TestClientByInspecting(t *testing.T) {
c := Client{}
r := &mockInspector{}
c.ResponseInspector = r.ByInspecting()
Respond(&http.Response{},
c.ByInspecting())
if !r.wasInvoked {
t.Fatal("autorest: Client#ByInspecting failed to invoke ResponseInspector")
}
}
func TestClientByInspectingSetsDefault(t *testing.T) {
c := Client{}
r := &http.Response{}
Respond(r,
c.ByInspecting())
if !reflect.DeepEqual(r, &http.Response{}) {
t.Fatal("autorest: Client#ByInspecting failed to provide a default ResponseInspector")
}
}
func TestCookies(t *testing.T) {
second := "second"
expected := http.Cookie{
Name: "tastes",
Value: "delicious",
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &expected)
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: ioutil.ReadAll failed reading request body: %s", err)
}
if string(b) == second {
cookie, err := r.Cookie(expected.Name)
if err != nil {
t.Fatalf("autorest: r.Cookie could not get request cookie: %s", err)
}
if cookie == nil {
t.Fatalf("autorest: got nil cookie, expecting %v", expected)
}
if cookie.Value != expected.Value {
t.Fatalf("autorest: got cookie value '%s', expecting '%s'", cookie.Value, expected.Name)
}
}
}))
defer server.Close()
client := NewClientWithUserAgent("")
_, err := SendWithSender(client, mocks.NewRequestForURL(server.URL))
if err != nil {
t.Fatalf("autorest: first request failed: %s", err)
}
r2, err := http.NewRequest(http.MethodGet, server.URL, mocks.NewBody(second))
if err != nil {
t.Fatalf("autorest: failed creating second request: %s", err)
}
_, err = SendWithSender(client, r2)
if err != nil {
t.Fatalf("autorest: second request failed: %s", err)
}
}
func randomString(n int) string {
const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
r := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
s := make([]byte, n)
for i := range s {
s[i] = chars[r.Intn(len(chars))]
}
return string(s)
}

View file

@ -1,237 +0,0 @@
package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
)
func ExampleParseDate() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03
}
func ExampleDate() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
t, err := time.Parse(time.RFC3339, "2001-02-04T00:00:00Z")
if err != nil {
fmt.Println(err)
}
// Date acts as time.Time when the receiver
if d.Before(t) {
fmt.Printf("Before ")
} else {
fmt.Printf("After ")
}
// Convert Date when needing a time.Time
if t.After(d.ToTime()) {
fmt.Printf("After")
} else {
fmt.Printf("Before")
}
// Output: Before After
}
func ExampleDate_MarshalBinary() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
t, err := d.MarshalBinary()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: 2001-02-03
}
func ExampleDate_UnmarshalBinary() {
d := Date{}
t := "2001-02-03"
if err := d.UnmarshalBinary([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03
}
func ExampleDate_MarshalJSON() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
j, err := json.Marshal(d)
if err != nil {
fmt.Println(err)
}
fmt.Println(string(j))
// Output: "2001-02-03"
}
func ExampleDate_UnmarshalJSON() {
var d struct {
Date Date `json:"date"`
}
j := `{"date" : "2001-02-03"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
fmt.Println(err)
}
fmt.Println(d.Date)
// Output: 2001-02-03
}
func ExampleDate_MarshalText() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
t, err := d.MarshalText()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: 2001-02-03
}
func ExampleDate_UnmarshalText() {
d := Date{}
t := "2001-02-03"
if err := d.UnmarshalText([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03
}
func TestDateString(t *testing.T) {
d, err := ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: String failed (%v)", err)
}
if d.String() != "2001-02-03" {
t.Fatalf("date: String failed (%v)", d.String())
}
}
func TestDateBinaryRoundTrip(t *testing.T) {
d1, err := ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: ParseDate failed (%v)", err)
}
t1, err := d1.MarshalBinary()
if err != nil {
t.Fatalf("date: MarshalBinary failed (%v)", err)
}
d2 := Date{}
if err = d2.UnmarshalBinary(t1); err != nil {
t.Fatalf("date: UnmarshalBinary failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Binary failed (%v, %v)", d1, d2)
}
}
func TestDateJSONRoundTrip(t *testing.T) {
type s struct {
Date Date `json:"date"`
}
var err error
d1 := s{}
d1.Date, err = ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: ParseDate failed (%v)", err)
}
j, err := json.Marshal(d1)
if err != nil {
t.Fatalf("date: MarshalJSON failed (%v)", err)
}
d2 := s{}
if err = json.Unmarshal(j, &d2); err != nil {
t.Fatalf("date: UnmarshalJSON failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip JSON failed (%v, %v)", d1, d2)
}
}
func TestDateTextRoundTrip(t *testing.T) {
d1, err := ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: ParseDate failed (%v)", err)
}
t1, err := d1.MarshalText()
if err != nil {
t.Fatalf("date: MarshalText failed (%v)", err)
}
d2 := Date{}
if err = d2.UnmarshalText(t1); err != nil {
t.Fatalf("date: UnmarshalText failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Text failed (%v, %v)", d1, d2)
}
}
func TestDateToTime(t *testing.T) {
var d Date
d, err := ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: ParseDate failed (%v)", err)
}
var _ time.Time = d.ToTime()
}
func TestDateUnmarshalJSONReturnsError(t *testing.T) {
var d struct {
Date Date `json:"date"`
}
j := `{"date" : "February 3, 2001"}`
if err := json.Unmarshal([]byte(j), &d); err == nil {
t.Fatal("date: Date failed to return error for malformed JSON date")
}
}
func TestDateUnmarshalTextReturnsError(t *testing.T) {
d := Date{}
txt := "February 3, 2001"
if err := d.UnmarshalText([]byte(txt)); err == nil {
t.Fatal("date: Date failed to return error for malformed Text date")
}
}

View file

@ -1,277 +0,0 @@
package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
)
func ExampleParseTime() {
d, _ := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
fmt.Println(d)
// Output: 2001-02-03 04:05:06 +0000 UTC
}
func ExampleTime_MarshalBinary() {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
d := Time{ti}
t, err := d.MarshalBinary()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: 2001-02-03T04:05:06Z
}
func ExampleTime_UnmarshalBinary() {
d := Time{}
t := "2001-02-03T04:05:06Z"
if err := d.UnmarshalBinary([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03T04:05:06Z
}
func ExampleTime_MarshalJSON() {
d, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
j, err := json.Marshal(d)
if err != nil {
fmt.Println(err)
}
fmt.Println(string(j))
// Output: "2001-02-03T04:05:06Z"
}
func ExampleTime_UnmarshalJSON() {
var d struct {
Time Time `json:"datetime"`
}
j := `{"datetime" : "2001-02-03T04:05:06Z"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
fmt.Println(err)
}
fmt.Println(d.Time)
// Output: 2001-02-03T04:05:06Z
}
func ExampleTime_MarshalText() {
d, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
t, err := d.MarshalText()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: 2001-02-03T04:05:06Z
}
func ExampleTime_UnmarshalText() {
d := Time{}
t := "2001-02-03T04:05:06Z"
if err := d.UnmarshalText([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03T04:05:06Z
}
func TestUnmarshalTextforInvalidDate(t *testing.T) {
d := Time{}
dt := "2001-02-03T04:05:06AAA"
if err := d.UnmarshalText([]byte(dt)); err == nil {
t.Fatalf("date: Time#Unmarshal was expecting error for invalid date")
}
}
func TestUnmarshalJSONforInvalidDate(t *testing.T) {
d := Time{}
dt := `"2001-02-03T04:05:06AAA"`
if err := d.UnmarshalJSON([]byte(dt)); err == nil {
t.Fatalf("date: Time#Unmarshal was expecting error for invalid date")
}
}
func TestTimeString(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
d := Time{ti}
if d.String() != "2001-02-03T04:05:06Z" {
t.Fatalf("date: Time#String failed (%v)", d.String())
}
}
func TestTimeStringReturnsEmptyStringForError(t *testing.T) {
d := Time{Time: time.Date(20000, 01, 01, 01, 01, 01, 01, time.UTC)}
if d.String() != "" {
t.Fatalf("date: Time#String failed empty string for an error")
}
}
func TestTimeBinaryRoundTrip(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
t.Fatalf("date: Time#ParseTime failed (%v)", err)
}
d1 := Time{ti}
t1, err := d1.MarshalBinary()
if err != nil {
t.Fatalf("date: Time#MarshalBinary failed (%v)", err)
}
d2 := Time{}
if err = d2.UnmarshalBinary(t1); err != nil {
t.Fatalf("date: Time#UnmarshalBinary failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date:Round-trip Binary failed (%v, %v)", d1, d2)
}
}
func TestTimeJSONRoundTrip(t *testing.T) {
type s struct {
Time Time `json:"datetime"`
}
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
t.Fatalf("date: Time#ParseTime failed (%v)", err)
}
d1 := s{Time: Time{ti}}
j, err := json.Marshal(d1)
if err != nil {
t.Fatalf("date: Time#MarshalJSON failed (%v)", err)
}
d2 := s{}
if err = json.Unmarshal(j, &d2); err != nil {
t.Fatalf("date: Time#UnmarshalJSON failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip JSON failed (%v, %v)", d1, d2)
}
}
func TestTimeTextRoundTrip(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
t.Fatalf("date: Time#ParseTime failed (%v)", err)
}
d1 := Time{Time: ti}
t1, err := d1.MarshalText()
if err != nil {
t.Fatalf("date: Time#MarshalText failed (%v)", err)
}
d2 := Time{}
if err = d2.UnmarshalText(t1); err != nil {
t.Fatalf("date: Time#UnmarshalText failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Text failed (%v, %v)", d1, d2)
}
}
func TestTimeToTime(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
d := Time{ti}
if err != nil {
t.Fatalf("date: Time#ParseTime failed (%v)", err)
}
var _ time.Time = d.ToTime()
}
func TestUnmarshalJSONNoOffset(t *testing.T) {
var d struct {
Time Time `json:"datetime"`
}
j := `{"datetime" : "2001-02-03T04:05:06.789"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
t.Fatalf("date: Time#Unmarshal failed (%v)", err)
}
}
func TestUnmarshalJSONPosOffset(t *testing.T) {
var d struct {
Time Time `json:"datetime"`
}
j := `{"datetime" : "1980-01-02T00:11:35.01+01:00"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
t.Fatalf("date: Time#Unmarshal failed (%v)", err)
}
}
func TestUnmarshalJSONNegOffset(t *testing.T) {
var d struct {
Time Time `json:"datetime"`
}
j := `{"datetime" : "1492-10-12T10:15:01.789-08:00"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
t.Fatalf("date: Time#Unmarshal failed (%v)", err)
}
}
func TestUnmarshalTextNoOffset(t *testing.T) {
d := Time{}
t1 := "2001-02-03T04:05:06"
if err := d.UnmarshalText([]byte(t1)); err != nil {
t.Fatalf("date: Time#UnmarshalText failed (%v)", err)
}
}
func TestUnmarshalTextPosOffset(t *testing.T) {
d := Time{}
t1 := "2001-02-03T04:05:06+00:30"
if err := d.UnmarshalText([]byte(t1)); err != nil {
t.Fatalf("date: Time#UnmarshalText failed (%v)", err)
}
}
func TestUnmarshalTextNegOffset(t *testing.T) {
d := Time{}
t1 := "2001-02-03T04:05:06-11:00"
if err := d.UnmarshalText([]byte(t1)); err != nil {
t.Fatalf("date: Time#UnmarshalText failed (%v)", err)
}
}

View file

@ -1,226 +0,0 @@
package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
)
func ExampleTimeRFC1123() {
d, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2006-01-02 15:04:05 +0000 MST
}
func ExampleTimeRFC1123_MarshalBinary() {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
fmt.Println(err)
}
d := TimeRFC1123{ti}
b, err := d.MarshalBinary()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(b))
// Output: Mon, 02 Jan 2006 15:04:05 MST
}
func ExampleTimeRFC1123_UnmarshalBinary() {
d := TimeRFC1123{}
t := "Mon, 02 Jan 2006 15:04:05 MST"
if err := d.UnmarshalBinary([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: Mon, 02 Jan 2006 15:04:05 MST
}
func ExampleTimeRFC1123_MarshalJSON() {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
fmt.Println(err)
}
d := TimeRFC1123{ti}
j, err := json.Marshal(d)
if err != nil {
fmt.Println(err)
}
fmt.Println(string(j))
// Output: "Mon, 02 Jan 2006 15:04:05 MST"
}
func TestTimeRFC1123MarshalJSONInvalid(t *testing.T) {
ti := time.Date(20000, 01, 01, 00, 00, 00, 00, time.UTC)
d := TimeRFC1123{ti}
if _, err := json.Marshal(d); err == nil {
t.Fatalf("date: TimeRFC1123#Marshal failed for invalid date")
}
}
func ExampleTimeRFC1123_UnmarshalJSON() {
var d struct {
Time TimeRFC1123 `json:"datetime"`
}
j := `{"datetime" : "Mon, 02 Jan 2006 15:04:05 MST"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
fmt.Println(err)
}
fmt.Println(d.Time)
// Output: Mon, 02 Jan 2006 15:04:05 MST
}
func ExampleTimeRFC1123_MarshalText() {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
d := TimeRFC1123{ti}
t, err := d.MarshalText()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: Sat, 03 Feb 2001 04:05:06 UTC
}
func ExampleTimeRFC1123_UnmarshalText() {
d := TimeRFC1123{}
t := "Sat, 03 Feb 2001 04:05:06 UTC"
if err := d.UnmarshalText([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: Sat, 03 Feb 2001 04:05:06 UTC
}
func TestUnmarshalJSONforInvalidDateRfc1123(t *testing.T) {
dt := `"Mon, 02 Jan 2000000 15:05 MST"`
d := TimeRFC1123{}
if err := d.UnmarshalJSON([]byte(dt)); err == nil {
t.Fatalf("date: TimeRFC1123#Unmarshal failed for invalid date")
}
}
func TestUnmarshalTextforInvalidDateRfc1123(t *testing.T) {
dt := "Mon, 02 Jan 2000000 15:05 MST"
d := TimeRFC1123{}
if err := d.UnmarshalText([]byte(dt)); err == nil {
t.Fatalf("date: TimeRFC1123#Unmarshal failed for invalid date")
}
}
func TestTimeStringRfc1123(t *testing.T) {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
fmt.Println(err)
}
d := TimeRFC1123{ti}
if d.String() != "Mon, 02 Jan 2006 15:04:05 MST" {
t.Fatalf("date: TimeRFC1123#String failed (%v)", d.String())
}
}
func TestTimeStringReturnsEmptyStringForErrorRfc1123(t *testing.T) {
d := TimeRFC1123{Time: time.Date(20000, 01, 01, 01, 01, 01, 01, time.UTC)}
if d.String() != "" {
t.Fatalf("date: TimeRFC1123#String failed empty string for an error")
}
}
func TestTimeBinaryRoundTripRfc1123(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
t.Fatalf("date: TimeRFC1123#ParseTime failed (%v)", err)
}
d1 := TimeRFC1123{ti}
t1, err := d1.MarshalBinary()
if err != nil {
t.Fatalf("date: TimeRFC1123#MarshalBinary failed (%v)", err)
}
d2 := TimeRFC1123{}
if err = d2.UnmarshalBinary(t1); err != nil {
t.Fatalf("date: TimeRFC1123#UnmarshalBinary failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Binary failed (%v, %v)", d1, d2)
}
}
func TestTimeJSONRoundTripRfc1123(t *testing.T) {
type s struct {
Time TimeRFC1123 `json:"datetime"`
}
var err error
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
t.Fatalf("date: TimeRFC1123#ParseTime failed (%v)", err)
}
d1 := s{Time: TimeRFC1123{ti}}
j, err := json.Marshal(d1)
if err != nil {
t.Fatalf("date: TimeRFC1123#MarshalJSON failed (%v)", err)
}
d2 := s{}
if err = json.Unmarshal(j, &d2); err != nil {
t.Fatalf("date: TimeRFC1123#UnmarshalJSON failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip JSON failed (%v, %v)", d1, d2)
}
}
func TestTimeTextRoundTripRfc1123(t *testing.T) {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
t.Fatalf("date: TimeRFC1123#ParseTime failed (%v)", err)
}
d1 := TimeRFC1123{Time: ti}
t1, err := d1.MarshalText()
if err != nil {
t.Fatalf("date: TimeRFC1123#MarshalText failed (%v)", err)
}
d2 := TimeRFC1123{}
if err = d2.UnmarshalText(t1); err != nil {
t.Fatalf("date: TimeRFC1123#UnmarshalText failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Text failed (%v, %v)", d1, d2)
}
}
func TestTimeToTimeRFC1123(t *testing.T) {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
d := TimeRFC1123{ti}
if err != nil {
t.Fatalf("date: TimeRFC1123#ParseTime failed (%v)", err)
}
var _ time.Time = d.ToTime()
}

View file

@ -1,283 +0,0 @@
// +build go1.7
package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"testing"
"time"
)
func ExampleUnixTime_MarshalJSON() {
epoch := UnixTime(UnixEpoch())
text, _ := json.Marshal(epoch)
fmt.Print(string(text))
// Output: 0
}
func ExampleUnixTime_UnmarshalJSON() {
var myTime UnixTime
json.Unmarshal([]byte("1.3e2"), &myTime)
fmt.Printf("%v", time.Time(myTime))
// Output: 1970-01-01 00:02:10 +0000 UTC
}
func TestUnixTime_MarshalJSON(t *testing.T) {
testCases := []time.Time{
UnixEpoch().Add(-1 * time.Second), // One second befote the Unix Epoch
time.Date(2017, time.April, 14, 20, 27, 47, 0, time.UTC), // The time this test was written
UnixEpoch(),
time.Date(1800, 01, 01, 0, 0, 0, 0, time.UTC),
time.Date(2200, 12, 29, 00, 01, 37, 82, time.UTC),
}
for _, tc := range testCases {
t.Run(tc.String(), func(subT *testing.T) {
var actual, expected float64
var marshaled []byte
target := UnixTime(tc)
expected = float64(target.Duration().Nanoseconds()) / 1e9
if temp, err := json.Marshal(target); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
dec := json.NewDecoder(bytes.NewReader(marshaled))
if err := dec.Decode(&actual); err != nil {
subT.Error(err)
return
}
diff := math.Abs(actual - expected)
subT.Logf("\ngot :\t%g\nwant:\t%g\ndiff:\t%g", actual, expected, diff)
if diff > 1e-9 { //Must be within 1 nanosecond of one another
subT.Fail()
}
})
}
}
func TestUnixTime_UnmarshalJSON(t *testing.T) {
testCases := []struct {
text string
expected time.Time
}{
{"1", UnixEpoch().Add(time.Second)},
{"0", UnixEpoch()},
{"1492203742", time.Date(2017, time.April, 14, 21, 02, 22, 0, time.UTC)}, // The time this test was written
{"-1", time.Date(1969, time.December, 31, 23, 59, 59, 0, time.UTC)},
{"1.5", UnixEpoch().Add(1500 * time.Millisecond)},
{"0e1", UnixEpoch()}, // See http://json.org for 'number' format definition.
{"1.3e+2", UnixEpoch().Add(130 * time.Second)},
{"1.6E-10", UnixEpoch()}, // This is so small, it should get truncated into the UnixEpoch
{"2E-6", UnixEpoch().Add(2 * time.Microsecond)},
{"1.289345e9", UnixEpoch().Add(1289345000 * time.Second)},
{"1e-9", UnixEpoch().Add(time.Nanosecond)},
}
for _, tc := range testCases {
t.Run(tc.text, func(subT *testing.T) {
var rehydrated UnixTime
if err := json.Unmarshal([]byte(tc.text), &rehydrated); err != nil {
subT.Error(err)
return
}
if time.Time(rehydrated) != tc.expected {
subT.Logf("\ngot: \t%v\nwant:\t%v\ndiff:\t%v", time.Time(rehydrated), tc.expected, time.Time(rehydrated).Sub(tc.expected))
subT.Fail()
}
})
}
}
func TestUnixTime_JSONRoundTrip(t *testing.T) {
testCases := []time.Time{
UnixEpoch(),
time.Date(2005, time.November, 5, 0, 0, 0, 0, time.UTC), // The day V for Vendetta (film) was released.
UnixEpoch().Add(-6 * time.Second),
UnixEpoch().Add(800 * time.Hour),
UnixEpoch().Add(time.Nanosecond),
time.Date(2015, time.September, 05, 4, 30, 12, 9992, time.UTC),
}
for _, tc := range testCases {
t.Run(tc.String(), func(subT *testing.T) {
subject := UnixTime(tc)
var marshaled []byte
if temp, err := json.Marshal(subject); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
var unmarshaled UnixTime
if err := json.Unmarshal(marshaled, &unmarshaled); err != nil {
subT.Error(err)
}
actual := time.Time(unmarshaled)
diff := actual.Sub(tc)
subT.Logf("\ngot :\t%s\nwant:\t%s\ndiff:\t%s", actual.String(), tc.String(), diff.String())
if diff > time.Duration(100) { // We lose some precision be working in floats. We shouldn't lose more than 100 nanoseconds.
subT.Fail()
}
})
}
}
func TestUnixTime_MarshalBinary(t *testing.T) {
testCases := []struct {
expected int64
subject time.Time
}{
{0, UnixEpoch()},
{-15 * int64(time.Second), UnixEpoch().Add(-15 * time.Second)},
{54, UnixEpoch().Add(54 * time.Nanosecond)},
}
for _, tc := range testCases {
t.Run("", func(subT *testing.T) {
var marshaled []byte
if temp, err := UnixTime(tc.subject).MarshalBinary(); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
var unmarshaled int64
if err := binary.Read(bytes.NewReader(marshaled), binary.LittleEndian, &unmarshaled); err != nil {
subT.Error(err)
return
}
if unmarshaled != tc.expected {
subT.Logf("\ngot: \t%d\nwant:\t%d", unmarshaled, tc.expected)
subT.Fail()
}
})
}
}
func TestUnixTime_BinaryRoundTrip(t *testing.T) {
testCases := []time.Time{
UnixEpoch(),
UnixEpoch().Add(800 * time.Minute),
UnixEpoch().Add(7 * time.Hour),
UnixEpoch().Add(-1 * time.Nanosecond),
}
for _, tc := range testCases {
t.Run(tc.String(), func(subT *testing.T) {
original := UnixTime(tc)
var marshaled []byte
if temp, err := original.MarshalBinary(); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
var traveled UnixTime
if err := traveled.UnmarshalBinary(marshaled); err != nil {
subT.Error(err)
return
}
if traveled != original {
subT.Logf("\ngot: \t%s\nwant:\t%s", time.Time(original).String(), time.Time(traveled).String())
subT.Fail()
}
})
}
}
func TestUnixTime_MarshalText(t *testing.T) {
testCases := []time.Time{
UnixEpoch(),
UnixEpoch().Add(45 * time.Second),
UnixEpoch().Add(time.Nanosecond),
UnixEpoch().Add(-100000 * time.Second),
}
for _, tc := range testCases {
expected, _ := tc.MarshalText()
t.Run("", func(subT *testing.T) {
var marshaled []byte
if temp, err := UnixTime(tc).MarshalText(); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
if string(marshaled) != string(expected) {
subT.Logf("\ngot: \t%s\nwant:\t%s", string(marshaled), string(expected))
subT.Fail()
}
})
}
}
func TestUnixTime_TextRoundTrip(t *testing.T) {
testCases := []time.Time{
UnixEpoch(),
UnixEpoch().Add(-1 * time.Nanosecond),
UnixEpoch().Add(1 * time.Nanosecond),
time.Date(2017, time.April, 17, 21, 00, 00, 00, time.UTC),
}
for _, tc := range testCases {
t.Run(tc.String(), func(subT *testing.T) {
unixTC := UnixTime(tc)
var marshaled []byte
if temp, err := unixTC.MarshalText(); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
var unmarshaled UnixTime
if err := unmarshaled.UnmarshalText(marshaled); err != nil {
subT.Error(err)
return
}
if unmarshaled != unixTC {
t.Logf("\ngot: \t%s\nwant:\t%s", time.Time(unmarshaled).String(), tc.String())
t.Fail()
}
})
}
}

View file

@ -1,202 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"net/http"
"reflect"
"regexp"
"testing"
)
func TestNewErrorWithError_AssignsPackageType(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if e.PackageType != "packageType" {
t.Fatalf("autorest: Error failed to set package type -- expected %v, received %v", "packageType", e.PackageType)
}
}
func TestNewErrorWithError_AssignsMethod(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if e.Method != "method" {
t.Fatalf("autorest: Error failed to set method -- expected %v, received %v", "method", e.Method)
}
}
func TestNewErrorWithError_AssignsMessage(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if e.Message != "message" {
t.Fatalf("autorest: Error failed to set message -- expected %v, received %v", "message", e.Message)
}
}
func TestNewErrorWithError_AssignsUndefinedStatusCodeIfRespNil(t *testing.T) {
e := NewErrorWithError(nil, "packageType", "method", nil, "message")
if e.StatusCode != UndefinedStatusCode {
t.Fatalf("autorest: Error failed to set status code -- expected %v, received %v", UndefinedStatusCode, e.StatusCode)
}
}
func TestNewErrorWithError_AssignsStatusCode(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", &http.Response{
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest)}, "message")
if e.StatusCode != http.StatusBadRequest {
t.Fatalf("autorest: Error failed to set status code -- expected %v, received %v", http.StatusBadRequest, e.StatusCode)
}
}
func TestNewErrorWithError_AcceptsArgs(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message %s", "arg")
if matched, _ := regexp.MatchString(`.*arg.*`, e.Message); !matched {
t.Fatalf("autorest: Error failed to apply message arguments -- expected %v, received %v",
`.*arg.*`, e.Message)
}
}
func TestNewErrorWithError_AssignsError(t *testing.T) {
err := fmt.Errorf("original")
e := NewErrorWithError(err, "packageType", "method", nil, "message")
if e.Original != err {
t.Fatalf("autorest: Error failed to set error -- expected %v, received %v", err, e.Original)
}
}
func TestNewErrorWithResponse_ContainsStatusCode(t *testing.T) {
e := NewErrorWithResponse("packageType", "method", &http.Response{
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest)}, "message")
if e.StatusCode != http.StatusBadRequest {
t.Fatalf("autorest: Error failed to set status code -- expected %v, received %v", http.StatusBadRequest, e.StatusCode)
}
}
func TestNewErrorWithResponse_nilResponse_ReportsUndefinedStatusCode(t *testing.T) {
e := NewErrorWithResponse("packageType", "method", nil, "message")
if e.StatusCode != UndefinedStatusCode {
t.Fatalf("autorest: Error failed to set status code -- expected %v, received %v", UndefinedStatusCode, e.StatusCode)
}
}
func TestNewErrorWithResponse_Forwards(t *testing.T) {
e1 := NewError("packageType", "method", "message %s", "arg")
e2 := NewErrorWithResponse("packageType", "method", nil, "message %s", "arg")
if !reflect.DeepEqual(e1, e2) {
t.Fatal("autorest: NewError did not return an error equivelent to NewErrorWithError")
}
}
func TestNewErrorWithError_Forwards(t *testing.T) {
e1 := NewError("packageType", "method", "message %s", "arg")
e2 := NewErrorWithError(nil, "packageType", "method", nil, "message %s", "arg")
if !reflect.DeepEqual(e1, e2) {
t.Fatal("autorest: NewError did not return an error equivelent to NewErrorWithError")
}
}
func TestNewErrorWithError_DoesNotWrapADetailedError(t *testing.T) {
e1 := NewError("packageType1", "method1", "message1 %s", "arg1")
e2 := NewErrorWithError(e1, "packageType2", "method2", nil, "message2 %s", "arg2")
if !reflect.DeepEqual(e1, e2) {
t.Fatalf("autorest: NewErrorWithError incorrectly wrapped a DetailedError -- expected %v, received %v", e1, e2)
}
}
func TestNewErrorWithError_WrapsAnError(t *testing.T) {
e1 := fmt.Errorf("Inner Error")
var e2 interface{} = NewErrorWithError(e1, "packageType", "method", nil, "message")
if _, ok := e2.(DetailedError); !ok {
t.Fatalf("autorest: NewErrorWithError failed to wrap a standard error -- received %T", e2)
}
}
func TestDetailedError(t *testing.T) {
err := fmt.Errorf("original")
e := NewErrorWithError(err, "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*original.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#Error failed to return original error message -- expected %v, received %v",
`.*original.*`, e.Error())
}
}
func TestDetailedErrorConstainsPackageType(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*packageType.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include PackageType -- expected %v, received %v",
`.*packageType.*`, e.Error())
}
}
func TestDetailedErrorConstainsMethod(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*method.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include Method -- expected %v, received %v",
`.*method.*`, e.Error())
}
}
func TestDetailedErrorConstainsMessage(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*message.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include Message -- expected %v, received %v",
`.*message.*`, e.Error())
}
}
func TestDetailedErrorConstainsStatusCode(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", &http.Response{
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest)}, "message")
if matched, _ := regexp.MatchString(`.*400.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include Status Code -- expected %v, received %v",
`.*400.*`, e.Error())
}
}
func TestDetailedErrorConstainsOriginal(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*original.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include Original error -- expected %v, received %v",
`.*original.*`, e.Error())
}
}
func TestDetailedErrorSkipsOriginal(t *testing.T) {
e := NewError("packageType", "method", "message")
if matched, _ := regexp.MatchString(`.*Original.*`, e.Error()); matched {
t.Fatalf("autorest: Error#String included missing Original error -- unexpected %v, received %v",
`.*Original.*`, e.Error())
}
}

View file

@ -27,8 +27,9 @@ import (
)
const (
mimeTypeJSON = "application/json"
mimeTypeFormPost = "application/x-www-form-urlencoded"
mimeTypeJSON = "application/json"
mimeTypeOctetStream = "application/octet-stream"
mimeTypeFormPost = "application/x-www-form-urlencoded"
headerAuthorization = "Authorization"
headerContentType = "Content-Type"
@ -112,6 +113,28 @@ func WithHeader(header string, value string) PrepareDecorator {
}
}
// WithHeaders returns a PrepareDecorator that sets the specified HTTP headers of the http.Request to
// the passed value. It canonicalizes the passed headers name (via http.CanonicalHeaderKey) before
// adding them.
func WithHeaders(headers map[string]interface{}) PrepareDecorator {
h := ensureValueStrings(headers)
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.Header == nil {
r.Header = make(http.Header)
}
for name, value := range h {
r.Header.Set(http.CanonicalHeaderKey(name), value)
}
}
return r, err
})
}
}
// WithBearerAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose
// value is "Bearer " followed by the supplied token.
func WithBearerAuthorization(token string) PrepareDecorator {
@ -142,6 +165,11 @@ func AsJSON() PrepareDecorator {
return AsContentType(mimeTypeJSON)
}
// AsOctetStream returns a PrepareDecorator that adds the "application/octet-stream" Content-Type header.
func AsOctetStream() PrepareDecorator {
return AsContentType(mimeTypeOctetStream)
}
// WithMethod returns a PrepareDecorator that sets the HTTP method of the passed request. The
// decorator does not validate that the passed method string is a known HTTP method.
func WithMethod(method string) PrepareDecorator {
@ -215,6 +243,11 @@ func WithFormData(v url.Values) PrepareDecorator {
r, err := p.Prepare(r)
if err == nil {
s := v.Encode()
if r.Header == nil {
r.Header = make(http.Header)
}
r.Header.Set(http.CanonicalHeaderKey(headerContentType), mimeTypeFormPost)
r.ContentLength = int64(len(s))
r.Body = ioutil.NopCloser(strings.NewReader(s))
}
@ -430,11 +463,16 @@ func WithQueryParameters(queryParameters map[string]interface{}) PrepareDecorato
if r.URL == nil {
return r, NewError("autorest", "WithQueryParameters", "Invoked with a nil URL")
}
v := r.URL.Query()
for key, value := range parameters {
v.Add(key, value)
d, err := url.QueryUnescape(value)
if err != nil {
return r, err
}
v.Add(key, d)
}
r.URL.RawQuery = createQuery(v)
r.URL.RawQuery = v.Encode()
}
return r, err
})

View file

@ -1,766 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
// PrepareDecorators wrap and invoke a Preparer. Most often, the decorator invokes the passed
// Preparer and decorates the response.
func ExamplePrepareDecorator() {
path := "a/b/c/"
pd := func() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.URL == nil {
return r, fmt.Errorf("ERROR: URL is not set")
}
r.URL.Path += path
}
return r, err
})
}
}
r, _ := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
pd())
fmt.Printf("Path is %s\n", r.URL)
// Output: Path is https://microsoft.com/a/b/c/
}
// PrepareDecorators may also modify and then invoke the Preparer.
func ExamplePrepareDecorator_pre() {
pd := func() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r.Header.Add(http.CanonicalHeaderKey("ContentType"), "application/json")
return p.Prepare(r)
})
}
}
r, _ := Prepare(&http.Request{Header: http.Header{}},
pd())
fmt.Printf("ContentType is %s\n", r.Header.Get("ContentType"))
// Output: ContentType is application/json
}
// Create a sequence of three Preparers that build up the URL path.
func ExampleCreatePreparer() {
p := CreatePreparer(
WithBaseURL("https://microsoft.com/"),
WithPath("a"),
WithPath("b"),
WithPath("c"))
r, err := p.Prepare(&http.Request{})
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c
}
// Create and apply separate Preparers
func ExampleCreatePreparer_multiple() {
params := map[string]interface{}{
"param1": "a",
"param2": "c",
}
p1 := CreatePreparer(WithBaseURL("https://microsoft.com/"))
p2 := CreatePreparer(WithPathParameters("/{param1}/b/{param2}/", params))
r, err := p1.Prepare(&http.Request{})
if err != nil {
fmt.Printf("ERROR: %v\n", err)
}
r, err = p2.Prepare(r)
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/
}
// Create and chain separate Preparers
func ExampleCreatePreparer_chain() {
params := map[string]interface{}{
"param1": "a",
"param2": "c",
}
p := CreatePreparer(WithBaseURL("https://microsoft.com/"))
p = DecoratePreparer(p, WithPathParameters("/{param1}/b/{param2}/", params))
r, err := p.Prepare(&http.Request{})
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/
}
// Create and prepare an http.Request in one call
func ExamplePrepare() {
r, err := Prepare(&http.Request{},
AsGet(),
WithBaseURL("https://microsoft.com/"),
WithPath("a/b/c/"))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Printf("%s %s", r.Method, r.URL)
}
// Output: GET https://microsoft.com/a/b/c/
}
// Create a request for a supplied base URL and path
func ExampleWithBaseURL() {
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/a/b/c/"))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/
}
func ExampleWithBaseURL_second() {
_, err := Prepare(&http.Request{}, WithBaseURL(":"))
fmt.Println(err)
// Output: parse :: missing protocol scheme
}
func ExampleWithCustomBaseURL() {
r, err := Prepare(&http.Request{},
WithCustomBaseURL("https://{account}.{service}.core.windows.net/",
map[string]interface{}{
"account": "myaccount",
"service": "blob",
}))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://myaccount.blob.core.windows.net/
}
func ExampleWithCustomBaseURL_second() {
_, err := Prepare(&http.Request{},
WithCustomBaseURL(":", map[string]interface{}{}))
fmt.Println(err)
// Output: parse :: missing protocol scheme
}
// Create a request with a custom HTTP header
func ExampleWithHeader() {
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/a/b/c/"),
WithHeader("x-foo", "bar"))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Printf("Header %s=%s\n", "x-foo", r.Header.Get("x-foo"))
}
// Output: Header x-foo=bar
}
// Create a request whose Body is the JSON encoding of a structure
func ExampleWithFormData() {
v := url.Values{}
v.Add("name", "Rob Pike")
v.Add("age", "42")
r, err := Prepare(&http.Request{},
WithFormData(v))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Printf("Request Body contains %s\n", string(b))
}
// Output: Request Body contains age=42&name=Rob+Pike
}
// Create a request whose Body is the JSON encoding of a structure
func ExampleWithJSON() {
t := mocks.T{Name: "Rob Pike", Age: 42}
r, err := Prepare(&http.Request{},
WithJSON(&t))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Printf("Request Body contains %s\n", string(b))
}
// Output: Request Body contains {"name":"Rob Pike","age":42}
}
// Create a request from a path with escaped parameters
func ExampleWithEscapedPathParameters() {
params := map[string]interface{}{
"param1": "a b c",
"param2": "d e f",
}
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
WithEscapedPathParameters("/{param1}/b/{param2}/", params))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a+b+c/b/d+e+f/
}
// Create a request from a path with parameters
func ExampleWithPathParameters() {
params := map[string]interface{}{
"param1": "a",
"param2": "c",
}
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
WithPathParameters("/{param1}/b/{param2}/", params))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/
}
// Create a request with query parameters
func ExampleWithQueryParameters() {
params := map[string]interface{}{
"q1": "value1",
"q2": "value2",
}
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
WithPath("/a/b/c/"),
WithQueryParameters(params))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/?q1=value1&q2=value2
}
func TestWithCustomBaseURL(t *testing.T) {
r, err := Prepare(&http.Request{}, WithCustomBaseURL("https://{account}.{service}.core.windows.net/",
map[string]interface{}{
"account": "myaccount",
"service": "blob",
}))
if err != nil {
t.Fatalf("autorest: WithCustomBaseURL should not fail")
}
if r.URL.String() != "https://myaccount.blob.core.windows.net/" {
t.Fatalf("autorest: WithCustomBaseURL expected https://myaccount.blob.core.windows.net/, got %s", r.URL)
}
}
func TestWithCustomBaseURLwithInvalidURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithCustomBaseURL("hello/{account}.{service}.core.windows.net/",
map[string]interface{}{
"account": "myaccount",
"service": "blob",
}))
if err == nil {
t.Fatalf("autorest: WithCustomBaseURL should fail fo URL parse error")
}
}
func TestWithPathWithInvalidPath(t *testing.T) {
p := "path%2*end"
if _, err := Prepare(&http.Request{}, WithBaseURL("https://microsoft.com/"), WithPath(p)); err == nil {
t.Fatalf("autorest: WithPath should fail for invalid URL escape error for path '%v' ", p)
}
}
func TestWithPathParametersWithInvalidPath(t *testing.T) {
p := "path%2*end"
m := map[string]interface{}{
"path1": p,
}
if _, err := Prepare(&http.Request{}, WithBaseURL("https://microsoft.com/"), WithPathParameters("/{path1}/", m)); err == nil {
t.Fatalf("autorest: WithPath should fail for invalid URL escape for path '%v' ", p)
}
}
func TestCreatePreparerDoesNotModify(t *testing.T) {
r1 := &http.Request{}
p := CreatePreparer()
r2, err := p.Prepare(r1)
if err != nil {
t.Fatalf("autorest: CreatePreparer failed (%v)", err)
}
if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: CreatePreparer without decorators modified the request")
}
}
func TestCreatePreparerRunsDecoratorsInOrder(t *testing.T) {
p := CreatePreparer(WithBaseURL("https://microsoft.com/"), WithPath("1"), WithPath("2"), WithPath("3"))
r, err := p.Prepare(&http.Request{})
if err != nil {
t.Fatalf("autorest: CreatePreparer failed (%v)", err)
}
if r.URL.String() != "https:/1/2/3" && r.URL.Host != "microsoft.com" {
t.Fatalf("autorest: CreatePreparer failed to run decorators in order")
}
}
func TestAsContentType(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), AsContentType("application/text"))
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.Header.Get(headerContentType) != "application/text" {
t.Fatalf("autorest: AsContentType failed to add header (%s=%s)", headerContentType, r.Header.Get(headerContentType))
}
}
func TestAsFormURLEncoded(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), AsFormURLEncoded())
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.Header.Get(headerContentType) != mimeTypeFormPost {
t.Fatalf("autorest: AsFormURLEncoded failed to add header (%s=%s)", headerContentType, r.Header.Get(headerContentType))
}
}
func TestAsJSON(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), AsJSON())
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.Header.Get(headerContentType) != mimeTypeJSON {
t.Fatalf("autorest: AsJSON failed to add header (%s=%s)", headerContentType, r.Header.Get(headerContentType))
}
}
func TestWithNothing(t *testing.T) {
r1 := mocks.NewRequest()
r2, err := Prepare(r1, WithNothing())
if err != nil {
t.Fatalf("autorest: WithNothing returned an unexpected error (%v)", err)
}
if !reflect.DeepEqual(r1, r2) {
t.Fatal("azure: WithNothing modified the passed HTTP Request")
}
}
func TestWithBearerAuthorization(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), WithBearerAuthorization("SOME-TOKEN"))
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.Header.Get(headerAuthorization) != "Bearer SOME-TOKEN" {
t.Fatalf("autorest: WithBearerAuthorization failed to add header (%s=%s)", headerAuthorization, r.Header.Get(headerAuthorization))
}
}
func TestWithUserAgent(t *testing.T) {
ua := "User Agent Go"
r, err := Prepare(mocks.NewRequest(), WithUserAgent(ua))
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.UserAgent() != ua || r.Header.Get(headerUserAgent) != ua {
t.Fatalf("autorest: WithUserAgent failed to add header (%s=%s)", headerUserAgent, r.Header.Get(headerUserAgent))
}
}
func TestWithMethod(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), WithMethod("HEAD"))
if r.Method != "HEAD" {
t.Fatal("autorest: WithMethod failed to set HTTP method header")
}
}
func TestAsDelete(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsDelete())
if r.Method != "DELETE" {
t.Fatal("autorest: AsDelete failed to set HTTP method header to DELETE")
}
}
func TestAsGet(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsGet())
if r.Method != "GET" {
t.Fatal("autorest: AsGet failed to set HTTP method header to GET")
}
}
func TestAsHead(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsHead())
if r.Method != "HEAD" {
t.Fatal("autorest: AsHead failed to set HTTP method header to HEAD")
}
}
func TestAsOptions(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsOptions())
if r.Method != "OPTIONS" {
t.Fatal("autorest: AsOptions failed to set HTTP method header to OPTIONS")
}
}
func TestAsPatch(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsPatch())
if r.Method != "PATCH" {
t.Fatal("autorest: AsPatch failed to set HTTP method header to PATCH")
}
}
func TestAsPost(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsPost())
if r.Method != "POST" {
t.Fatal("autorest: AsPost failed to set HTTP method header to POST")
}
}
func TestAsPut(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsPut())
if r.Method != "PUT" {
t.Fatal("autorest: AsPut failed to set HTTP method header to PUT")
}
}
func TestPrepareWithNullRequest(t *testing.T) {
_, err := Prepare(nil)
if err == nil {
t.Fatal("autorest: Prepare failed to return an error when given a null http.Request")
}
}
func TestWithFormDataSetsContentLength(t *testing.T) {
v := url.Values{}
v.Add("name", "Rob Pike")
v.Add("age", "42")
r, err := Prepare(&http.Request{},
WithFormData(v))
if err != nil {
t.Fatalf("autorest: WithFormData failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithFormData failed with error (%v)", err)
}
expected := "name=Rob+Pike&age=42"
if !(string(b) == "name=Rob+Pike&age=42" || string(b) == "age=42&name=Rob+Pike") {
t.Fatalf("autorest:WithFormData failed to return correct string got (%v), expected (%v)", string(b), expected)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithFormData set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithMultiPartFormDataSetsContentLength(t *testing.T) {
v := map[string]interface{}{
"file": ioutil.NopCloser(strings.NewReader("Hello Gopher")),
"age": "42",
}
r, err := Prepare(&http.Request{},
WithMultiPartFormData(v))
if err != nil {
t.Fatalf("autorest: WithMultiPartFormData failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithMultiPartFormData failed with error (%v)", err)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithMultiPartFormData set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithMultiPartFormDataWithNoFile(t *testing.T) {
v := map[string]interface{}{
"file": "no file",
"age": "42",
}
r, err := Prepare(&http.Request{},
WithMultiPartFormData(v))
if err != nil {
t.Fatalf("autorest: WithMultiPartFormData failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithMultiPartFormData failed with error (%v)", err)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithMultiPartFormData set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithFile(t *testing.T) {
r, err := Prepare(&http.Request{},
WithFile(ioutil.NopCloser(strings.NewReader("Hello Gopher"))))
if err != nil {
t.Fatalf("autorest: WithFile failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithFile failed with error (%v)", err)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithFile set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithBool_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithBool(false))
if err != nil {
t.Fatalf("autorest: WithBool failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithBool failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", false))) {
t.Fatalf("autorest: WithBool set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", false))))
}
v, err := strconv.ParseBool(string(s))
if err != nil || v {
t.Fatalf("autorest: WithBool incorrectly encoded the boolean as %v", s)
}
}
func TestWithFloat32_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithFloat32(42.0))
if err != nil {
t.Fatalf("autorest: WithFloat32 failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithFloat32 failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", 42.0))) {
t.Fatalf("autorest: WithFloat32 set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", 42.0))))
}
v, err := strconv.ParseFloat(string(s), 32)
if err != nil || float32(v) != float32(42.0) {
t.Fatalf("autorest: WithFloat32 incorrectly encoded the boolean as %v", s)
}
}
func TestWithFloat64_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithFloat64(42.0))
if err != nil {
t.Fatalf("autorest: WithFloat64 failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithFloat64 failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", 42.0))) {
t.Fatalf("autorest: WithFloat64 set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", 42.0))))
}
v, err := strconv.ParseFloat(string(s), 64)
if err != nil || v != float64(42.0) {
t.Fatalf("autorest: WithFloat64 incorrectly encoded the boolean as %v", s)
}
}
func TestWithInt32_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithInt32(42))
if err != nil {
t.Fatalf("autorest: WithInt32 failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithInt32 failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", 42))) {
t.Fatalf("autorest: WithInt32 set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", 42))))
}
v, err := strconv.ParseInt(string(s), 10, 32)
if err != nil || int32(v) != int32(42) {
t.Fatalf("autorest: WithInt32 incorrectly encoded the boolean as %v", s)
}
}
func TestWithInt64_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithInt64(42))
if err != nil {
t.Fatalf("autorest: WithInt64 failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithInt64 failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", 42))) {
t.Fatalf("autorest: WithInt64 set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", 42))))
}
v, err := strconv.ParseInt(string(s), 10, 64)
if err != nil || v != int64(42) {
t.Fatalf("autorest: WithInt64 incorrectly encoded the boolean as %v", s)
}
}
func TestWithString_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithString("value"))
if err != nil {
t.Fatalf("autorest: WithString failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithString failed with error (%v)", err)
}
if r.ContentLength != int64(len("value")) {
t.Fatalf("autorest: WithString set Content-Length to %v, expected %v", r.ContentLength, int64(len("value")))
}
if string(s) != "value" {
t.Fatalf("autorest: WithString incorrectly encoded the string as %v", s)
}
}
func TestWithJSONSetsContentLength(t *testing.T) {
r, err := Prepare(&http.Request{},
WithJSON(&mocks.T{Name: "Rob Pike", Age: 42}))
if err != nil {
t.Fatalf("autorest: WithJSON failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithJSON failed with error (%v)", err)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithJSON set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithHeaderAllocatesHeaders(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), WithHeader("x-foo", "bar"))
if err != nil {
t.Fatalf("autorest: WithHeader failed (%v)", err)
}
if r.Header.Get("x-foo") != "bar" {
t.Fatalf("autorest: WithHeader failed to add header (%s=%s)", "x-foo", r.Header.Get("x-foo"))
}
}
func TestWithPathCatchesNilURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithPath("a"))
if err == nil {
t.Fatalf("autorest: WithPath failed to catch a nil URL")
}
}
func TestWithEscapedPathParametersCatchesNilURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithEscapedPathParameters("", map[string]interface{}{"foo": "bar"}))
if err == nil {
t.Fatalf("autorest: WithEscapedPathParameters failed to catch a nil URL")
}
}
func TestWithPathParametersCatchesNilURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithPathParameters("", map[string]interface{}{"foo": "bar"}))
if err == nil {
t.Fatalf("autorest: WithPathParameters failed to catch a nil URL")
}
}
func TestWithQueryParametersCatchesNilURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithQueryParameters(map[string]interface{}{"foo": "bar"}))
if err == nil {
t.Fatalf("autorest: WithQueryParameters failed to catch a nil URL")
}
}
func TestModifyingExistingRequest(t *testing.T) {
r, err := Prepare(mocks.NewRequestForURL("https://bing.com"), WithPath("search"), WithQueryParameters(map[string]interface{}{"q": "golang"}))
if err != nil {
t.Fatalf("autorest: Preparing an existing request returned an error (%v)", err)
}
if r.URL.String() != "https:/search?q=golang" && r.URL.Host != "bing.com" {
t.Fatalf("autorest: Preparing an existing request failed (%s)", r.URL)
}
}

View file

@ -1,665 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strings"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
func ExampleWithErrorUnlessOK() {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
// Respond and leave the response body open (for a subsequent responder to close)
err := Respond(r,
WithErrorUnlessOK(),
ByDiscardingBody(),
ByClosingIfError())
if err == nil {
fmt.Printf("%s of %s returned HTTP 200", r.Request.Method, r.Request.URL)
// Complete handling the response and close the body
Respond(r,
ByDiscardingBody(),
ByClosing())
}
// Output: GET of https://microsoft.com/a/b/c/ returned HTTP 200
}
func ExampleByUnmarshallingJSON() {
c := `
{
"name" : "Rob Pike",
"age" : 42
}
`
type V struct {
Name string `json:"name"`
Age int `json:"age"`
}
v := &V{}
Respond(mocks.NewResponseWithContent(c),
ByUnmarshallingJSON(v),
ByClosing())
fmt.Printf("%s is %d years old\n", v.Name, v.Age)
// Output: Rob Pike is 42 years old
}
func ExampleByUnmarshallingXML() {
c := `<?xml version="1.0" encoding="UTF-8"?>
<Person>
<Name>Rob Pike</Name>
<Age>42</Age>
</Person>`
type V struct {
Name string `xml:"Name"`
Age int `xml:"Age"`
}
v := &V{}
Respond(mocks.NewResponseWithContent(c),
ByUnmarshallingXML(v),
ByClosing())
fmt.Printf("%s is %d years old\n", v.Name, v.Age)
// Output: Rob Pike is 42 years old
}
func TestCreateResponderDoesNotModify(t *testing.T) {
r1 := mocks.NewResponse()
r2 := mocks.NewResponse()
p := CreateResponder()
err := p.Respond(r1)
if err != nil {
t.Fatalf("autorest: CreateResponder failed (%v)", err)
}
if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: CreateResponder without decorators modified the response")
}
}
func TestCreateResponderRunsDecoratorsInOrder(t *testing.T) {
s := ""
d := func(n int) RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err == nil {
s += fmt.Sprintf("%d", n)
}
return err
})
}
}
p := CreateResponder(d(1), d(2), d(3))
err := p.Respond(&http.Response{})
if err != nil {
t.Fatalf("autorest: Respond failed (%v)", err)
}
if s != "123" {
t.Fatalf("autorest: CreateResponder invoked decorators in an incorrect order; expected '123', received '%s'", s)
}
}
func TestByIgnoring(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(r2 *http.Response) error {
r1 := mocks.NewResponse()
if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: ByIgnoring modified the HTTP Response -- received %v, expected %v", r2, r1)
}
return nil
})
}
})(),
ByIgnoring(),
ByClosing())
}
func TestByCopying_Copies(t *testing.T) {
r := mocks.NewResponseWithContent(jsonT)
b := &bytes.Buffer{}
err := Respond(r,
ByCopying(b),
ByUnmarshallingJSON(&mocks.T{}),
ByClosing())
if err != nil {
t.Fatalf("autorest: ByCopying returned an unexpected error -- %v", err)
}
if b.String() != jsonT {
t.Fatalf("autorest: ByCopying failed to copy the bytes read")
}
}
func TestByCopying_ReturnsNestedErrors(t *testing.T) {
r := mocks.NewResponseWithContent(jsonT)
r.Body.Close()
err := Respond(r,
ByCopying(&bytes.Buffer{}),
ByUnmarshallingJSON(&mocks.T{}),
ByClosing())
if err == nil {
t.Fatalf("autorest: ByCopying failed to return the expected error")
}
}
func TestByCopying_AcceptsNilReponse(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
r.Respond(nil)
return nil
})
}
})(),
ByCopying(&bytes.Buffer{}))
}
func TestByCopying_AcceptsNilBody(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
resp.Body = nil
r.Respond(resp)
return nil
})
}
})(),
ByCopying(&bytes.Buffer{}))
}
func TestByClosing(t *testing.T) {
r := mocks.NewResponse()
err := Respond(r, ByClosing())
if err != nil {
t.Fatalf("autorest: ByClosing failed (%v)", err)
}
if r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: ByClosing did not close the response body")
}
}
func TestByClosingAcceptsNilResponse(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
r.Respond(nil)
return nil
})
}
})(),
ByClosing())
}
func TestByClosingAcceptsNilBody(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
resp.Body = nil
r.Respond(resp)
return nil
})
}
})(),
ByClosing())
}
func TestByClosingClosesEvenAfterErrors(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
ByClosing())
if r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: ByClosing did not close the response body after an error occurred")
}
}
func TestByClosingClosesReturnsNestedErrors(t *testing.T) {
var e error
r := mocks.NewResponse()
err := Respond(r,
withErrorRespondDecorator(&e),
ByClosing())
if err == nil || !reflect.DeepEqual(e, err) {
t.Fatalf("autorest: ByClosing failed to return a nested error")
}
}
func TestByClosingIfErrorAcceptsNilResponse(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
r.Respond(nil)
return nil
})
}
})(),
ByClosingIfError())
}
func TestByClosingIfErrorAcceptsNilBody(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
resp.Body = nil
r.Respond(resp)
return nil
})
}
})(),
ByClosingIfError())
}
func TestByClosingIfErrorClosesIfAnErrorOccurs(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
ByClosingIfError())
if r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: ByClosingIfError did not close the response body after an error occurred")
}
}
func TestByClosingIfErrorDoesNotClosesIfNoErrorOccurs(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
ByClosingIfError())
if !r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: ByClosingIfError closed the response body even though no error occurred")
}
}
func TestByDiscardingBody(t *testing.T) {
r := mocks.NewResponse()
err := Respond(r,
ByDiscardingBody())
if err != nil {
t.Fatalf("autorest: ByDiscardingBody failed (%v)", err)
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: Reading result of ByDiscardingBody failed (%v)", err)
}
if len(buf) != 0 {
t.Logf("autorest: Body was not empty after calling ByDiscardingBody.")
t.Fail()
}
}
func TestByDiscardingBodyAcceptsNilResponse(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
r.Respond(nil)
return nil
})
}
})(),
ByDiscardingBody())
}
func TestByDiscardingBodyAcceptsNilBody(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
resp.Body = nil
r.Respond(resp)
return nil
})
}
})(),
ByDiscardingBody())
}
func TestByUnmarshallingJSON(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: ByUnmarshallingJSON failed (%v)", err)
}
if v.Name != "Rob Pike" || v.Age != 42 {
t.Fatalf("autorest: ByUnmarshallingJSON failed to properly unmarshal")
}
}
func TestByUnmarshallingJSON_HandlesReadErrors(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
r.Body.(*mocks.Body).Close()
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err == nil {
t.Fatalf("autorest: ByUnmarshallingJSON failed to receive / respond to read error")
}
}
func TestByUnmarshallingJSONIncludesJSONInErrors(t *testing.T) {
v := &mocks.T{}
j := jsonT[0 : len(jsonT)-2]
r := mocks.NewResponseWithContent(j)
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err == nil || !strings.Contains(err.Error(), j) {
t.Fatalf("autorest: ByUnmarshallingJSON failed to return JSON in error (%v)", err)
}
}
func TestByUnmarshallingJSONEmptyInput(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(``)
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: ByUnmarshallingJSON failed to return nil in case of empty JSON (%v)", err)
}
}
func TestByUnmarshallingXML(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(xmlT)
err := Respond(r,
ByUnmarshallingXML(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: ByUnmarshallingXML failed (%v)", err)
}
if v.Name != "Rob Pike" || v.Age != 42 {
t.Fatalf("autorest: ByUnmarshallingXML failed to properly unmarshal")
}
}
func TestByUnmarshallingXML_HandlesReadErrors(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(xmlT)
r.Body.(*mocks.Body).Close()
err := Respond(r,
ByUnmarshallingXML(v),
ByClosing())
if err == nil {
t.Fatalf("autorest: ByUnmarshallingXML failed to receive / respond to read error")
}
}
func TestByUnmarshallingXMLIncludesXMLInErrors(t *testing.T) {
v := &mocks.T{}
x := xmlT[0 : len(xmlT)-2]
r := mocks.NewResponseWithContent(x)
err := Respond(r,
ByUnmarshallingXML(v),
ByClosing())
if err == nil || !strings.Contains(err.Error(), x) {
t.Fatalf("autorest: ByUnmarshallingXML failed to return XML in error (%v)", err)
}
}
func TestRespondAcceptsNullResponse(t *testing.T) {
err := Respond(nil)
if err != nil {
t.Fatalf("autorest: Respond returned an unexpected error when given a null Response (%v)", err)
}
}
func TestWithErrorUnlessStatusCodeOKResponse(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
err := Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) failed on okay response. (%v)", err)
}
if v.Name != "Rob Pike" || v.Age != 42 {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) corrupted the response body of okay response.")
}
}
func TesWithErrorUnlessStatusCodeErrorResponse(t *testing.T) {
v := &mocks.T{}
e := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
r.Status = "400 BadRequest"
r.StatusCode = http.StatusBadRequest
err := Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
ByUnmarshallingJSON(v),
ByClosing())
if err == nil {
t.Fatal("autorest: WithErrorUnlessStatusCode(http.StatusOK) did not return error, on a response to a bad request.")
}
var errorRespBody []byte
if derr, ok := err.(DetailedError); !ok {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) got wrong error type : %T, expected: DetailedError, on a response to a bad request.", err)
} else {
errorRespBody = derr.ServiceError
}
if errorRespBody == nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) ServiceError not returned in DetailedError on a response to a bad request.")
}
err = json.Unmarshal(errorRespBody, e)
if err != nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) cannot parse error returned in ServiceError into json. %v", err)
}
expected := &mocks.T{Name: "Rob Pike", Age: 42}
if e != expected {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK wrong value from parsed ServiceError: got=%#v expected=%#v", e, expected)
}
}
func TestWithErrorUnlessStatusCode(t *testing.T) {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
r.Status = "400 BadRequest"
r.StatusCode = http.StatusBadRequest
err := Respond(r,
WithErrorUnlessStatusCode(http.StatusBadRequest, http.StatusUnauthorized, http.StatusInternalServerError),
ByClosingIfError())
if err != nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode returned an error (%v) for an acceptable status code (%s)", err, r.Status)
}
}
func TestWithErrorUnlessStatusCodeEmitsErrorForUnacceptableStatusCode(t *testing.T) {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
r.Status = "400 BadRequest"
r.StatusCode = http.StatusBadRequest
err := Respond(r,
WithErrorUnlessStatusCode(http.StatusOK, http.StatusUnauthorized, http.StatusInternalServerError),
ByClosingIfError())
if err == nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode failed to return an error for an unacceptable status code (%s)", r.Status)
}
}
func TestWithErrorUnlessOK(t *testing.T) {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
err := Respond(r,
WithErrorUnlessOK(),
ByClosingIfError())
if err != nil {
t.Fatalf("autorest: WithErrorUnlessOK returned an error for OK status code (%v)", err)
}
}
func TestWithErrorUnlessOKEmitsErrorIfNotOK(t *testing.T) {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
r.Status = "400 BadRequest"
r.StatusCode = http.StatusBadRequest
err := Respond(r,
WithErrorUnlessOK(),
ByClosingIfError())
if err == nil {
t.Fatalf("autorest: WithErrorUnlessOK failed to return an error for a non-OK status code (%v)", err)
}
}
func TestExtractHeader(t *testing.T) {
r := mocks.NewResponse()
v := []string{"v1", "v2", "v3"}
mocks.SetResponseHeaderValues(r, mocks.TestHeader, v)
if !reflect.DeepEqual(ExtractHeader(mocks.TestHeader, r), v) {
t.Fatalf("autorest: ExtractHeader failed to retrieve the expected header -- expected [%s]%v, received [%s]%v",
mocks.TestHeader, v, mocks.TestHeader, ExtractHeader(mocks.TestHeader, r))
}
}
func TestExtractHeaderHandlesMissingHeader(t *testing.T) {
var v []string
r := mocks.NewResponse()
if !reflect.DeepEqual(ExtractHeader(mocks.TestHeader, r), v) {
t.Fatalf("autorest: ExtractHeader failed to handle a missing header -- expected %v, received %v",
v, ExtractHeader(mocks.TestHeader, r))
}
}
func TestExtractHeaderValue(t *testing.T) {
r := mocks.NewResponse()
v := "v1"
mocks.SetResponseHeader(r, mocks.TestHeader, v)
if ExtractHeaderValue(mocks.TestHeader, r) != v {
t.Fatalf("autorest: ExtractHeader failed to retrieve the expected header -- expected [%s]%v, received [%s]%v",
mocks.TestHeader, v, mocks.TestHeader, ExtractHeaderValue(mocks.TestHeader, r))
}
}
func TestExtractHeaderValueHandlesMissingHeader(t *testing.T) {
r := mocks.NewResponse()
v := ""
if ExtractHeaderValue(mocks.TestHeader, r) != v {
t.Fatalf("autorest: ExtractHeader failed to retrieve the expected header -- expected [%s]%v, received [%s]%v",
mocks.TestHeader, v, mocks.TestHeader, ExtractHeaderValue(mocks.TestHeader, r))
}
}
func TestExtractHeaderValueRetrievesFirstValue(t *testing.T) {
r := mocks.NewResponse()
v := []string{"v1", "v2", "v3"}
mocks.SetResponseHeaderValues(r, mocks.TestHeader, v)
if ExtractHeaderValue(mocks.TestHeader, r) != v[0] {
t.Fatalf("autorest: ExtractHeader failed to retrieve the expected header -- expected [%s]%v, received [%s]%v",
mocks.TestHeader, v[0], mocks.TestHeader, ExtractHeaderValue(mocks.TestHeader, r))
}
}

View file

@ -215,20 +215,26 @@ func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) Se
rr := NewRetriableRequest(r)
// Increment to add the first call (attempts denotes number of retries)
attempts++
for attempt := 0; attempt < attempts; attempt++ {
for attempt := 0; attempt < attempts; {
err = rr.Prepare()
if err != nil {
return resp, err
}
resp, err = s.Do(rr.Request())
// we want to retry if err is not nil (e.g. transient network failure)
if err == nil && !ResponseHasStatusCode(resp, codes...) {
// we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication
// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {
return resp, err
}
delayed := DelayWithRetryAfter(resp, r.Cancel)
if !delayed {
DelayForBackoff(backoff, attempt, r.Cancel)
}
// don't count a 429 against the number of attempts
// so that we continue to retry until it succeeds
if resp == nil || resp.StatusCode != http.StatusTooManyRequests {
attempt++
}
}
return resp, err
})

View file

@ -1,811 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"fmt"
"log"
"net/http"
"os"
"reflect"
"sync"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/mocks"
)
func ExampleSendWithSender() {
r := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(r)
client := mocks.NewSender()
client.AppendAndRepeatResponse(r, 10)
logger := log.New(os.Stdout, "autorest: ", 0)
na := NullAuthorizer{}
req, _ := Prepare(&http.Request{},
AsGet(),
WithBaseURL("https://microsoft.com/a/b/c/"),
na.WithAuthorization())
r, _ = SendWithSender(client, req,
WithLogging(logger),
DoErrorIfStatusCode(http.StatusAccepted),
DoCloseIfError(),
DoRetryForAttempts(5, time.Duration(0)))
Respond(r,
ByDiscardingBody(),
ByClosing())
// Output:
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
}
func ExampleDoRetryForAttempts() {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), 10)
// Retry with backoff -- ensure returned Bodies are closed
r, _ := SendWithSender(client, mocks.NewRequest(),
DoCloseIfError(),
DoRetryForAttempts(5, time.Duration(0)))
Respond(r,
ByDiscardingBody(),
ByClosing())
fmt.Printf("Retry stopped after %d attempts", client.Attempts())
// Output: Retry stopped after 5 attempts
}
func ExampleDoErrorIfStatusCode() {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("204 NoContent", http.StatusNoContent), 10)
// Chain decorators to retry the request, up to five times, if the status code is 204
r, _ := SendWithSender(client, mocks.NewRequest(),
DoErrorIfStatusCode(http.StatusNoContent),
DoCloseIfError(),
DoRetryForAttempts(5, time.Duration(0)))
Respond(r,
ByDiscardingBody(),
ByClosing())
fmt.Printf("Retry stopped after %d attempts with code %s", client.Attempts(), r.Status)
// Output: Retry stopped after 5 attempts with code 204 NoContent
}
func TestSendWithSenderRunsDecoratorsInOrder(t *testing.T) {
client := mocks.NewSender()
s := ""
r, err := SendWithSender(client, mocks.NewRequest(),
withMessage(&s, "a"),
withMessage(&s, "b"),
withMessage(&s, "c"))
if err != nil {
t.Fatalf("autorest: SendWithSender returned an error (%v)", err)
}
Respond(r,
ByDiscardingBody(),
ByClosing())
if s != "abc" {
t.Fatalf("autorest: SendWithSender invoke decorators out of order; expected 'abc', received '%s'", s)
}
}
func TestCreateSender(t *testing.T) {
f := false
s := CreateSender(
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return nil, nil
})
}
})())
s.Do(&http.Request{})
if !f {
t.Fatal("autorest: CreateSender failed to apply supplied decorator")
}
}
func TestSend(t *testing.T) {
f := false
Send(&http.Request{},
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return nil, nil
})
}
})())
if !f {
t.Fatal("autorest: Send failed to apply supplied decorator")
}
}
func TestAfterDelayWaits(t *testing.T) {
client := mocks.NewSender()
d := 2 * time.Second
tt := time.Now()
r, _ := SendWithSender(client, mocks.NewRequest(),
AfterDelay(d))
s := time.Since(tt)
if s < d {
t.Fatal("autorest: AfterDelay failed to wait for at least the specified duration")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestAfterDelay_Cancels(t *testing.T) {
client := mocks.NewSender()
cancel := make(chan struct{})
delay := 5 * time.Second
var wg sync.WaitGroup
wg.Add(1)
tt := time.Now()
go func() {
req := mocks.NewRequest()
req.Cancel = cancel
wg.Done()
SendWithSender(client, req,
AfterDelay(delay))
}()
wg.Wait()
close(cancel)
time.Sleep(5 * time.Millisecond)
if time.Since(tt) >= delay {
t.Fatal("autorest: AfterDelay failed to cancel")
}
}
func TestAfterDelayDoesNotWaitTooLong(t *testing.T) {
client := mocks.NewSender()
d := 5 * time.Millisecond
start := time.Now()
r, _ := SendWithSender(client, mocks.NewRequest(),
AfterDelay(d))
if time.Since(start) > (5 * d) {
t.Fatal("autorest: AfterDelay waited too long (exceeded 5 times specified duration)")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestAsIs(t *testing.T) {
client := mocks.NewSender()
r1 := mocks.NewResponse()
client.AppendResponse(r1)
r2, err := SendWithSender(client, mocks.NewRequest(),
AsIs())
if err != nil {
t.Fatalf("autorest: AsIs returned an unexpected error (%v)", err)
} else if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: AsIs modified the response -- received %v, expected %v", r2, r1)
}
Respond(r1,
ByDiscardingBody(),
ByClosing())
Respond(r2,
ByDiscardingBody(),
ByClosing())
}
func TestDoCloseIfError(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(mocks.NewResponseWithStatus("400 BadRequest", http.StatusBadRequest))
r, _ := SendWithSender(client, mocks.NewRequest(),
DoErrorIfStatusCode(http.StatusBadRequest),
DoCloseIfError())
if r.Body.(*mocks.Body).IsOpen() {
t.Fatal("autorest: Expected DoCloseIfError to close response body -- it was left open")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoCloseIfErrorAcceptsNilResponse(t *testing.T) {
client := mocks.NewSender()
SendWithSender(client, mocks.NewRequest(),
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err != nil {
resp.Body.Close()
}
return nil, fmt.Errorf("Faux Error")
})
}
})(),
DoCloseIfError())
}
func TestDoCloseIfErrorAcceptsNilBody(t *testing.T) {
client := mocks.NewSender()
SendWithSender(client, mocks.NewRequest(),
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err != nil {
resp.Body.Close()
}
resp.Body = nil
return resp, fmt.Errorf("Faux Error")
})
}
})(),
DoCloseIfError())
}
func TestDoErrorIfStatusCode(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(mocks.NewResponseWithStatus("400 BadRequest", http.StatusBadRequest))
r, err := SendWithSender(client, mocks.NewRequest(),
DoErrorIfStatusCode(http.StatusBadRequest),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: DoErrorIfStatusCode failed to emit an error for passed code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoErrorIfStatusCodeIgnoresStatusCodes(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(newAcceptedResponse())
r, err := SendWithSender(client, mocks.NewRequest(),
DoErrorIfStatusCode(http.StatusBadRequest),
DoCloseIfError())
if err != nil {
t.Fatal("autorest: DoErrorIfStatusCode failed to ignore a status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoErrorUnlessStatusCode(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(mocks.NewResponseWithStatus("400 BadRequest", http.StatusBadRequest))
r, err := SendWithSender(client, mocks.NewRequest(),
DoErrorUnlessStatusCode(http.StatusAccepted),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: DoErrorUnlessStatusCode failed to emit an error for an unknown status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoErrorUnlessStatusCodeIgnoresStatusCodes(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(newAcceptedResponse())
r, err := SendWithSender(client, mocks.NewRequest(),
DoErrorUnlessStatusCode(http.StatusAccepted),
DoCloseIfError())
if err != nil {
t.Fatal("autorest: DoErrorUnlessStatusCode emitted an error for a knonwn status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForAttemptsStopsAfterSuccess(t *testing.T) {
client := mocks.NewSender()
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForAttempts(5, time.Duration(0)))
if client.Attempts() != 1 {
t.Fatalf("autorest: DoRetryForAttempts failed to stop after success -- expected attempts %v, actual %v",
1, client.Attempts())
}
if err != nil {
t.Fatalf("autorest: DoRetryForAttempts returned an unexpected error (%v)", err)
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForAttemptsStopsAfterAttempts(t *testing.T) {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), 10)
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForAttempts(5, time.Duration(0)),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 5 {
t.Fatal("autorest: DoRetryForAttempts failed to stop after specified number of attempts")
}
}
func TestDoRetryForAttemptsReturnsResponse(t *testing.T) {
client := mocks.NewSender()
client.SetError(fmt.Errorf("Faux Error"))
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForAttempts(1, time.Duration(0)))
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
if r == nil {
t.Fatal("autorest: DoRetryForAttempts failed to return the underlying response")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForDurationStopsAfterSuccess(t *testing.T) {
client := mocks.NewSender()
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForDuration(10*time.Millisecond, time.Duration(0)))
if client.Attempts() != 1 {
t.Fatalf("autorest: DoRetryForDuration failed to stop after success -- expected attempts %v, actual %v",
1, client.Attempts())
}
if err != nil {
t.Fatalf("autorest: DoRetryForDuration returned an unexpected error (%v)", err)
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForDurationStopsAfterDuration(t *testing.T) {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), -1)
d := 5 * time.Millisecond
start := time.Now()
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForDuration(d, time.Duration(0)),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
if time.Since(start) < d {
t.Fatal("autorest: DoRetryForDuration failed stopped too soon")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForDurationStopsWithinReason(t *testing.T) {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), -1)
d := 5 * time.Second
start := time.Now()
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForDuration(d, time.Duration(0)),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
if time.Since(start) > (5 * d) {
t.Fatal("autorest: DoRetryForDuration failed stopped soon enough (exceeded 5 times specified duration)")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForDurationReturnsResponse(t *testing.T) {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), -1)
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForDuration(10*time.Millisecond, time.Duration(0)),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
if r == nil {
t.Fatal("autorest: DoRetryForDuration failed to return the underlying response")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDelayForBackoff(t *testing.T) {
d := 2 * time.Second
start := time.Now()
DelayForBackoff(d, 0, nil)
if time.Since(start) < d {
t.Fatal("autorest: DelayForBackoff did not delay as long as expected")
}
}
func TestDelayForBackoff_Cancels(t *testing.T) {
cancel := make(chan struct{})
delay := 5 * time.Second
var wg sync.WaitGroup
wg.Add(1)
start := time.Now()
go func() {
wg.Done()
DelayForBackoff(delay, 0, cancel)
}()
wg.Wait()
close(cancel)
time.Sleep(5 * time.Millisecond)
if time.Since(start) >= delay {
t.Fatal("autorest: DelayForBackoff failed to cancel")
}
}
func TestDelayForBackoffWithinReason(t *testing.T) {
d := 5 * time.Second
maxCoefficient := 2
start := time.Now()
DelayForBackoff(d, 0, nil)
if time.Since(start) > (time.Duration(maxCoefficient) * d) {
t.Fatalf("autorest: DelayForBackoff delayed too long (exceeded %d times the specified duration)", maxCoefficient)
}
}
func TestDoPollForStatusCodes_IgnoresUnspecifiedStatusCodes(t *testing.T) {
client := mocks.NewSender()
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Duration(0), time.Duration(0)))
if client.Attempts() != 1 {
t.Fatalf("autorest: Sender#DoPollForStatusCodes polled for unspecified status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_PollsForSpecifiedStatusCodes(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(newAcceptedResponse())
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if client.Attempts() != 2 {
t.Fatalf("autorest: Sender#DoPollForStatusCodes failed to poll for specified status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_CanBeCanceled(t *testing.T) {
cancel := make(chan struct{})
delay := 5 * time.Second
r := mocks.NewResponse()
mocks.SetAcceptedHeaders(r)
client := mocks.NewSender()
client.AppendAndRepeatResponse(r, 100)
var wg sync.WaitGroup
wg.Add(1)
start := time.Now()
go func() {
wg.Done()
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
Respond(r,
ByDiscardingBody(),
ByClosing())
}()
wg.Wait()
close(cancel)
time.Sleep(5 * time.Millisecond)
if time.Since(start) >= delay {
t.Fatalf("autorest: Sender#DoPollForStatusCodes failed to cancel")
}
}
func TestDoPollForStatusCodes_ClosesAllNonreturnedResponseBodiesWhenPolling(t *testing.T) {
resp := newAcceptedResponse()
client := mocks.NewSender()
client.AppendAndRepeatResponse(resp, 2)
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if resp.Body.(*mocks.Body).IsOpen() || resp.Body.(*mocks.Body).CloseAttempts() < 2 {
t.Fatalf("autorest: Sender#DoPollForStatusCodes did not close unreturned response bodies")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_LeavesLastResponseBodyOpen(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(newAcceptedResponse())
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if !r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: Sender#DoPollForStatusCodes did not leave open the body of the last response")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_StopsPollingAfterAnError(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(newAcceptedResponse(), 5)
client.SetError(fmt.Errorf("Faux Error"))
client.SetEmitErrorAfter(1)
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if client.Attempts() > 2 {
t.Fatalf("autorest: Sender#DoPollForStatusCodes failed to stop polling after receiving an error")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_ReturnsPollingError(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(newAcceptedResponse(), 5)
client.SetError(fmt.Errorf("Faux Error"))
client.SetEmitErrorAfter(1)
r, err := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if err == nil {
t.Fatalf("autorest: Sender#DoPollForStatusCodes failed to return error from polling")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestWithLogging_Logs(t *testing.T) {
buf := &bytes.Buffer{}
logger := log.New(buf, "autorest: ", 0)
client := mocks.NewSender()
r, _ := SendWithSender(client, &http.Request{},
WithLogging(logger))
if buf.String() == "" {
t.Fatal("autorest: Sender#WithLogging failed to log the request")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestWithLogging_HandlesMissingResponse(t *testing.T) {
buf := &bytes.Buffer{}
logger := log.New(buf, "autorest: ", 0)
client := mocks.NewSender()
client.AppendResponse(nil)
client.SetError(fmt.Errorf("Faux Error"))
r, err := SendWithSender(client, &http.Request{},
WithLogging(logger))
if r != nil || err == nil {
t.Fatal("autorest: Sender#WithLogging returned a valid response -- expecting nil")
}
if buf.String() == "" {
t.Fatal("autorest: Sender#WithLogging failed to log the request for a nil response")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForStatusCodesWithSuccess(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("408 Request Timeout", http.StatusRequestTimeout), 2)
client.AppendResponse(mocks.NewResponseWithStatus("200 OK", http.StatusOK))
r, _ := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(5, time.Duration(2*time.Second), http.StatusRequestTimeout),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 3 {
t.Fatalf("autorest: Sender#DoRetryForStatusCodes -- Got: StatusCode %v in %v attempts; Want: StatusCode 200 OK in 2 attempts -- ",
r.Status, client.Attempts()-1)
}
}
func TestDoRetryForStatusCodesWithNoSuccess(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("504 Gateway Timeout", http.StatusGatewayTimeout), 5)
r, _ := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(2, time.Duration(2*time.Second), http.StatusGatewayTimeout),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 3 {
t.Fatalf("autorest: Sender#DoRetryForStatusCodes -- Got: failed stop after %v retry attempts; Want: Stop after 2 retry attempts",
client.Attempts()-1)
}
}
func TestDoRetryForStatusCodes_CodeNotInRetryList(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("204 No Content", http.StatusNoContent), 1)
r, _ := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(6, time.Duration(2*time.Second), http.StatusGatewayTimeout),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 1 || r.Status != "204 No Content" {
t.Fatalf("autorest: Sender#DoRetryForStatusCodes -- Got: Retry attempts %v for StatusCode %v; Want: 0 attempts for StatusCode 204",
client.Attempts(), r.Status)
}
}
func TestDoRetryForStatusCodes_RequestBodyReadError(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("204 No Content", http.StatusNoContent), 2)
r, err := SendWithSender(client, mocks.NewRequestWithCloseBody(),
DoRetryForStatusCodes(6, time.Duration(2*time.Second), http.StatusGatewayTimeout),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if err == nil || client.Attempts() != 0 {
t.Fatalf("autorest: Sender#DoRetryForStatusCodes -- Got: Not failed for request body read error; Want: Failed for body read error - %v", err)
}
}
func newAcceptedResponse() *http.Response {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
return resp
}
func TestDelayWithRetryAfterWithSuccess(t *testing.T) {
after, retries := 5, 2
totalSecs := after * retries
client := mocks.NewSender()
resp := mocks.NewResponseWithStatus("429 Too many requests", http.StatusTooManyRequests)
mocks.SetResponseHeader(resp, "Retry-After", fmt.Sprintf("%v", after))
client.AppendAndRepeatResponse(resp, retries)
client.AppendResponse(mocks.NewResponseWithStatus("200 OK", http.StatusOK))
d := time.Second * time.Duration(totalSecs)
start := time.Now()
r, _ := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(5, time.Duration(time.Second), http.StatusTooManyRequests),
)
if time.Since(start) < d {
t.Fatal("autorest: DelayWithRetryAfter failed stopped too soon")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 3 {
t.Fatalf("autorest: Sender#DelayWithRetryAfter -- Got: StatusCode %v in %v attempts; Want: StatusCode 200 OK in 2 attempts -- ",
r.Status, client.Attempts()-1)
}
}

View file

@ -23,8 +23,9 @@ import (
"net/http"
"net/url"
"reflect"
"sort"
"strings"
"github.com/Azure/go-autorest/autorest/adal"
)
// EncodedAs is a series of constants specifying various data encodings
@ -138,13 +139,38 @@ func MapToValues(m map[string]interface{}) url.Values {
return v
}
// String method converts interface v to string. If interface is a list, it
// joins list elements using separator.
func String(v interface{}, sep ...string) string {
if len(sep) > 0 {
return ensureValueString(strings.Join(v.([]string), sep[0]))
// AsStringSlice method converts interface{} to []string. This expects a
//that the parameter passed to be a slice or array of a type that has the underlying
//type a string.
func AsStringSlice(s interface{}) ([]string, error) {
v := reflect.ValueOf(s)
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, NewError("autorest", "AsStringSlice", "the value's type is not an array.")
}
return ensureValueString(v)
stringSlice := make([]string, 0, v.Len())
for i := 0; i < v.Len(); i++ {
stringSlice = append(stringSlice, v.Index(i).String())
}
return stringSlice, nil
}
// String method converts interface v to string. If interface is a list, it
// joins list elements using the seperator. Note that only sep[0] will be used for
// joining if any separator is specified.
func String(v interface{}, sep ...string) string {
if len(sep) == 0 {
return ensureValueString(v)
}
stringSlice, ok := v.([]string)
if ok == false {
var err error
stringSlice, err = AsStringSlice(v)
if err != nil {
panic(fmt.Sprintf("autorest: Couldn't convert value to a string %s.", err))
}
}
return ensureValueString(strings.Join(stringSlice, sep[0]))
}
// Encode method encodes url path and query parameters.
@ -168,30 +194,6 @@ func queryEscape(s string) string {
return url.QueryEscape(s)
}
// This method is same as Encode() method of "net/url" go package,
// except it does not encode the query parameters because they
// already come encoded. It formats values map in query format (bar=foo&a=b).
func createQuery(v url.Values) string {
var buf bytes.Buffer
keys := make([]string, 0, len(v))
for k := range v {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vs := v[k]
prefix := url.QueryEscape(k) + "="
for _, v := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(prefix)
buf.WriteString(v)
}
}
return buf.String()
}
// ChangeToGet turns the specified http.Request into a GET (it assumes it wasn't).
// This is mainly useful for long-running operations that use the Azure-AsyncOperation
// header, so we change the initial PUT into a GET to retrieve the final result.
@ -202,3 +204,15 @@ func ChangeToGet(req *http.Request) *http.Request {
req.Header.Del("Content-Length")
return req
}
// IsTokenRefreshError returns true if the specified error implements the TokenRefreshError
// interface. If err is a DetailedError it will walk the chain of Original errors.
func IsTokenRefreshError(err error) bool {
if _, ok := err.(adal.TokenRefreshError); ok {
return true
}
if de, ok := err.(DetailedError); ok {
return IsTokenRefreshError(de.Original)
}
return false
}

View file

@ -1,382 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"net/url"
"reflect"
"sort"
"strings"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
jsonT = `
{
"name":"Rob Pike",
"age":42
}`
xmlT = `<?xml version="1.0" encoding="UTF-8"?>
<Person>
<Name>Rob Pike</Name>
<Age>42</Age>
</Person>`
)
func TestNewDecoderCreatesJSONDecoder(t *testing.T) {
d := NewDecoder(EncodedAsJSON, strings.NewReader(jsonT))
_, ok := d.(*json.Decoder)
if d == nil || !ok {
t.Fatal("autorest: NewDecoder failed to create a JSON decoder when requested")
}
}
func TestNewDecoderCreatesXMLDecoder(t *testing.T) {
d := NewDecoder(EncodedAsXML, strings.NewReader(xmlT))
_, ok := d.(*xml.Decoder)
if d == nil || !ok {
t.Fatal("autorest: NewDecoder failed to create an XML decoder when requested")
}
}
func TestNewDecoderReturnsNilForUnknownEncoding(t *testing.T) {
d := NewDecoder("unknown", strings.NewReader(xmlT))
if d != nil {
t.Fatal("autorest: NewDecoder created a decoder for an unknown encoding")
}
}
func TestCopyAndDecodeDecodesJSON(t *testing.T) {
_, err := CopyAndDecode(EncodedAsJSON, strings.NewReader(jsonT), &mocks.T{})
if err != nil {
t.Fatalf("autorest: CopyAndDecode returned an error with valid JSON - %v", err)
}
}
func TestCopyAndDecodeDecodesXML(t *testing.T) {
_, err := CopyAndDecode(EncodedAsXML, strings.NewReader(xmlT), &mocks.T{})
if err != nil {
t.Fatalf("autorest: CopyAndDecode returned an error with valid XML - %v", err)
}
}
func TestCopyAndDecodeReturnsJSONDecodingErrors(t *testing.T) {
_, err := CopyAndDecode(EncodedAsJSON, strings.NewReader(jsonT[0:len(jsonT)-2]), &mocks.T{})
if err == nil {
t.Fatalf("autorest: CopyAndDecode failed to return an error with invalid JSON")
}
}
func TestCopyAndDecodeReturnsXMLDecodingErrors(t *testing.T) {
_, err := CopyAndDecode(EncodedAsXML, strings.NewReader(xmlT[0:len(xmlT)-2]), &mocks.T{})
if err == nil {
t.Fatalf("autorest: CopyAndDecode failed to return an error with invalid XML")
}
}
func TestCopyAndDecodeAlwaysReturnsACopy(t *testing.T) {
b, _ := CopyAndDecode(EncodedAsJSON, strings.NewReader(jsonT), &mocks.T{})
if b.String() != jsonT {
t.Fatalf("autorest: CopyAndDecode failed to return a valid copy of the data - %v", b.String())
}
}
func TestTeeReadCloser_Copies(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
b := &bytes.Buffer{}
r.Body = TeeReadCloser(r.Body, b)
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: TeeReadCloser returned an unexpected error -- %v", err)
}
if b.String() != jsonT {
t.Fatalf("autorest: TeeReadCloser failed to copy the bytes read")
}
}
func TestTeeReadCloser_PassesReadErrors(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
r.Body.(*mocks.Body).Close()
r.Body = TeeReadCloser(r.Body, &bytes.Buffer{})
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err == nil {
t.Fatalf("autorest: TeeReadCloser failed to return the expected error")
}
}
func TestTeeReadCloser_ClosesWrappedReader(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
b := r.Body.(*mocks.Body)
r.Body = TeeReadCloser(r.Body, &bytes.Buffer{})
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: TeeReadCloser returned an unexpected error -- %v", err)
}
if b.IsOpen() {
t.Fatalf("autorest: TeeReadCloser failed to close the nested io.ReadCloser")
}
}
func TestContainsIntFindsValue(t *testing.T) {
ints := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := 5
if !containsInt(ints, v) {
t.Fatalf("autorest: containsInt failed to find %v in %v", v, ints)
}
}
func TestContainsIntDoesNotFindValue(t *testing.T) {
ints := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := 42
if containsInt(ints, v) {
t.Fatalf("autorest: containsInt unexpectedly found %v in %v", v, ints)
}
}
func TestContainsIntAcceptsEmptyList(t *testing.T) {
ints := make([]int, 10)
if containsInt(ints, 42) {
t.Fatalf("autorest: containsInt failed to handle an empty list")
}
}
func TestContainsIntAcceptsNilList(t *testing.T) {
var ints []int
if containsInt(ints, 42) {
t.Fatalf("autorest: containsInt failed to handle an nil list")
}
}
func TestEscapeStrings(t *testing.T) {
m := map[string]string{
"string": "a long string with = odd characters",
"int": "42",
"nil": "",
}
r := map[string]string{
"string": "a+long+string+with+%3D+odd+characters",
"int": "42",
"nil": "",
}
v := escapeValueStrings(m)
if !reflect.DeepEqual(v, r) {
t.Fatalf("autorest: ensureValueStrings returned %v\n", v)
}
}
func TestEnsureStrings(t *testing.T) {
m := map[string]interface{}{
"string": "string",
"int": 42,
"nil": nil,
"bytes": []byte{255, 254, 253},
}
r := map[string]string{
"string": "string",
"int": "42",
"nil": "",
"bytes": string([]byte{255, 254, 253}),
}
v := ensureValueStrings(m)
if !reflect.DeepEqual(v, r) {
t.Fatalf("autorest: ensureValueStrings returned %v\n", v)
}
}
func ExampleString() {
m := []string{
"string1",
"string2",
"string3",
}
fmt.Println(String(m, ","))
// Output: string1,string2,string3
}
func TestStringWithValidString(t *testing.T) {
i := 123
if String(i) != "123" {
t.Fatal("autorest: String method failed to convert integer 123 to string")
}
}
func TestEncodeWithValidPath(t *testing.T) {
s := Encode("Path", "Hello Gopher")
if s != "Hello%20Gopher" {
t.Fatalf("autorest: Encode method failed for valid path encoding. Got: %v; Want: %v", s, "Hello%20Gopher")
}
}
func TestEncodeWithValidQuery(t *testing.T) {
s := Encode("Query", "Hello Gopher")
if s != "Hello+Gopher" {
t.Fatalf("autorest: Encode method failed for valid query encoding. Got: '%v'; Want: 'Hello+Gopher'", s)
}
}
func TestEncodeWithValidNotPathQuery(t *testing.T) {
s := Encode("Host", "Hello Gopher")
if s != "Hello Gopher" {
t.Fatalf("autorest: Encode method failed for parameter not query or path. Got: '%v'; Want: 'Hello Gopher'", s)
}
}
func TestMapToValues(t *testing.T) {
m := map[string]interface{}{
"a": "a",
"b": 2,
}
v := url.Values{}
v.Add("a", "a")
v.Add("b", "2")
if !isEqual(v, MapToValues(m)) {
t.Fatalf("autorest: MapToValues method failed to return correct values - expected(%v) got(%v)", v, MapToValues(m))
}
}
func TestMapToValuesWithArrayValues(t *testing.T) {
m := map[string]interface{}{
"a": []string{"a", "b"},
"b": 2,
"c": []int{3, 4},
}
v := url.Values{}
v.Add("a", "a")
v.Add("a", "b")
v.Add("b", "2")
v.Add("c", "3")
v.Add("c", "4")
if !isEqual(v, MapToValues(m)) {
t.Fatalf("autorest: MapToValues method failed to return correct values - expected(%v) got(%v)", v, MapToValues(m))
}
}
func isEqual(v, u url.Values) bool {
for key, value := range v {
if len(u[key]) == 0 {
return false
}
sort.Strings(value)
sort.Strings(u[key])
for i := range value {
if value[i] != u[key][i] {
return false
}
}
u.Del(key)
}
if len(u) > 0 {
return false
}
return true
}
func doEnsureBodyClosed(t *testing.T) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if resp != nil && resp.Body != nil && resp.Body.(*mocks.Body).IsOpen() {
t.Fatal("autorest: Expected Body to be closed -- it was left open")
}
return resp, err
})
}
}
type mockAuthorizer struct{}
func (ma mockAuthorizer) WithAuthorization() PrepareDecorator {
return WithHeader(headerAuthorization, mocks.TestAuthorizationHeader)
}
type mockFailingAuthorizer struct{}
func (mfa mockFailingAuthorizer) WithAuthorization() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
return r, fmt.Errorf("ERROR: mockFailingAuthorizer returned expected error")
})
}
}
type mockInspector struct {
wasInvoked bool
}
func (mi *mockInspector) WithInspection() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
mi.wasInvoked = true
return p.Prepare(r)
})
}
}
func (mi *mockInspector) ByInspecting() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
mi.wasInvoked = true
return r.Respond(resp)
})
}
}
func withMessage(output *string, msg string) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err == nil {
*output += msg
}
return resp, err
})
}
}
func withErrorRespondDecorator(e *error) RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err != nil {
return err
}
*e = fmt.Errorf("autorest: Faux Respond Error")
return *e
})
}
}

View file

@ -22,9 +22,9 @@ import (
)
const (
major = 8
minor = 0
patch = 0
major = 9
minor = 8
patch = 1
tag = ""
)