generalize cidr parsing and improve lua tests

This commit is contained in:
Elvin Efendi 2021-01-04 15:01:55 -05:00
parent 2254a91866
commit 2cff9fa41d
7 changed files with 68 additions and 70 deletions

View file

@ -19,7 +19,6 @@ package ratelimit
import (
"encoding/base64"
"fmt"
"sort"
"strings"
networking "k8s.io/api/networking/v1beta1"
@ -164,7 +163,7 @@ func (a ratelimit) Parse(ing *networking.Ingress) (interface{}, error) {
val, _ := parser.GetStringAnnotation("limit-whitelist", ing)
cidrs, err := parseCIDRs(val)
cidrs, err := net.ParseCIDRs(val)
if err != nil {
return nil, err
}
@ -208,32 +207,6 @@ func (a ratelimit) Parse(ing *networking.Ingress) (interface{}, error) {
}, nil
}
func parseCIDRs(s string) ([]string, error) {
if s == "" {
return []string{}, nil
}
values := strings.Split(s, ",")
ipnets, ips, err := net.ParseIPNets(values...)
if err != nil {
return nil, err
}
cidrs := []string{}
for k := range ipnets {
cidrs = append(cidrs, k)
}
for k := range ips {
cidrs = append(cidrs, k)
}
sort.Strings(cidrs)
return cidrs, nil
}
func encode(s string) string {
str := base64.URLEncoding.EncodeToString([]byte(s))
return strings.Replace(str, "=", "", -1)

View file

@ -17,8 +17,6 @@ limitations under the License.
package ratelimit
import (
"reflect"
"sort"
"testing"
api "k8s.io/api/core/v1"
@ -85,23 +83,6 @@ func TestWithoutAnnotations(t *testing.T) {
}
}
func TestParseCIDRs(t *testing.T) {
cidr, _ := parseCIDRs("invalid.com")
if cidr != nil {
t.Errorf("expected %v but got %v", nil, cidr)
}
expected := []string{"192.0.0.1", "192.0.1.0/24"}
cidr, err := parseCIDRs("192.0.0.1, 192.0.1.0/24")
if err != nil {
t.Errorf("unexpected error %v", err)
}
sort.Strings(cidr)
if !reflect.DeepEqual(expected, cidr) {
t.Errorf("expected %v but got %v", expected, cidr)
}
}
func TestRateLimiting(t *testing.T) {
ing := buildIngress()

View file

@ -18,6 +18,7 @@ package net
import (
"net"
"sort"
"strings"
)
@ -51,3 +52,30 @@ func ParseIPNets(specs ...string) (IPNet, IP, error) {
return ipnetset, ipset, nil
}
// ParseCIDRs parses comma separated CIDRs into a sorted string array
func ParseCIDRs(s string) ([]string, error) {
if s == "" {
return []string{}, nil
}
values := strings.Split(s, ",")
ipnets, ips, err := ParseIPNets(values...)
if err != nil {
return nil, err
}
cidrs := []string{}
for k := range ipnets {
cidrs = append(cidrs, k)
}
for k := range ips {
cidrs = append(cidrs, k)
}
sort.Strings(cidrs)
return cidrs, nil
}

View file

@ -17,6 +17,8 @@ limitations under the License.
package net
import (
"reflect"
"sort"
"testing"
)
@ -32,3 +34,20 @@ func TestNewIPSet(t *testing.T) {
t.Errorf("Expected len=1: %d", len(ips))
}
}
func TestParseCIDRs(t *testing.T) {
cidr, _ := ParseCIDRs("invalid.com")
if cidr != nil {
t.Errorf("expected %v but got %v", nil, cidr)
}
expected := []string{"192.0.0.1", "192.0.1.0/24"}
cidr, err := ParseCIDRs("192.0.0.1, 192.0.1.0/24")
if err != nil {
t.Errorf("unexpected error %v", err)
}
sort.Strings(cidr)
if !reflect.DeepEqual(expected, cidr) {
t.Errorf("expected %v but got %v", expected, cidr)
}
}