diff --git a/go.mod b/go.mod index dd1ebb457..db407b5fe 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/mitchellh/go-ps v0.0.0-20170309133038-4fdf99ab2936 github.com/onsi/ginkgo/v2 v2.9.0 github.com/onsi/gomega v1.27.1 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 google.golang.org/grpc v1.53.0 k8s.io/api v0.25.6 k8s.io/apimachinery v0.25.6 @@ -81,7 +82,7 @@ require ( golang.org/x/term v0.11.0 // indirect golang.org/x/text v0.12.0 // indirect golang.org/x/time v0.0.0-20220609170525-579cf78fd858 // indirect - golang.org/x/tools v0.6.0 // indirect + golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect google.golang.org/protobuf v1.28.1 // indirect diff --git a/go.sum b/go.sum index ee54fc300..3e614c819 100644 --- a/go.sum +++ b/go.sum @@ -382,6 +382,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -404,6 +406,7 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= +golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -585,8 +588,8 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E= +golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 011caa7dd..0bcc46905 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -55,6 +55,12 @@ type AccessPoint struct { // Capacity is used for testing purpose only // EFS does not consider capacity while provisioning new file systems or access points CapacityGiB int64 + PosixUser *PosixUser +} + +type PosixUser struct { + Gid int64 + Uid int64 } type AccessPointOptions struct { @@ -91,6 +97,7 @@ type Cloud interface { CreateAccessPoint(ctx context.Context, volumeName string, accessPointOpts *AccessPointOptions) (accessPoint *AccessPoint, err error) DeleteAccessPoint(ctx context.Context, accessPointId string) (err error) DescribeAccessPoint(ctx context.Context, accessPointId string) (accessPoint *AccessPoint, err error) + ListAccessPoints(ctx context.Context, fileSystemId string) (accessPoints []*AccessPoint, err error) DescribeFileSystem(ctx context.Context, fileSystemId string) (fs *FileSystem, err error) DescribeMountTargets(ctx context.Context, fileSystemId, az string) (fs *MountTarget, err error) } @@ -233,6 +240,37 @@ func (c *cloud) DescribeAccessPoint(ctx context.Context, accessPointId string) ( }, nil } +func (c *cloud) ListAccessPoints(ctx context.Context, fileSystemId string) (accessPoints []*AccessPoint, err error) { + describeAPInput := &efs.DescribeAccessPointsInput{ + FileSystemId: &fileSystemId, + } + res, err := c.efs.DescribeAccessPointsWithContext(ctx, describeAPInput) + if err != nil { + if isAccessDenied(err) { + return + } + if isFileSystemNotFound(err) { + return + } + err = fmt.Errorf("List Access Points failed: %v", err) + return + } + + for _, accessPointDescription := range res.AccessPoints { + accessPoint := &AccessPoint{ + AccessPointId: *accessPointDescription.AccessPointId, + FileSystemId: *accessPointDescription.FileSystemId, + PosixUser: &PosixUser{ + Gid: *accessPointDescription.PosixUser.Gid, + Uid: *accessPointDescription.PosixUser.Gid, + }, + } + accessPoints = append(accessPoints, accessPoint) + } + + return +} + func (c *cloud) DescribeFileSystem(ctx context.Context, fileSystemId string) (fs *FileSystem, err error) { describeFsInput := &efs.DescribeFileSystemsInput{FileSystemId: &fileSystemId} klog.V(5).Infof("Calling DescribeFileSystems with input: %+v", *describeFsInput) diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 6e2079751..337e2441f 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -443,6 +443,125 @@ func TestDescribeAccessPoint(t *testing.T) { } } +func TestListAccessPoints(t *testing.T) { + var ( + fsId = "fs-abcd1234" + accessPointId = "ap-abc123" + Gid int64 = 1000 + Uid int64 = 1000 + ) + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "Success", + testFunc: func(t *testing.T) { + mockctl := gomock.NewController(t) + mockEfs := mocks.NewMockEfs(mockctl) + c := &cloud{efs: mockEfs} + + output := &efs.DescribeAccessPointsOutput{ + AccessPoints: []*efs.AccessPointDescription{ + { + AccessPointId: aws.String(accessPointId), + FileSystemId: aws.String(fsId), + PosixUser: &efs.PosixUser{ + Gid: aws.Int64(Gid), + Uid: aws.Int64(Uid), + }, + }, + }, + NextToken: nil, + } + + ctx := context.Background() + mockEfs.EXPECT().DescribeAccessPointsWithContext(gomock.Eq(ctx), gomock.Any()).Return(output, nil) + res, err := c.ListAccessPoints(ctx, fsId) + if err != nil { + t.Fatalf("List Access Points failed: %v", err) + } + + if res == nil { + t.Fatal("Result is nil") + } + + if len(res) != 1 { + t.Fatalf("Expected only one AccessPoint in response but got: %v", res) + } + + mockctl.Finish() + }, + }, + { + name: "Success - multiple access points", + testFunc: func(t *testing.T) { + mockctl := gomock.NewController(t) + mockEfs := mocks.NewMockEfs(mockctl) + c := &cloud{efs: mockEfs} + + output := &efs.DescribeAccessPointsOutput{ + AccessPoints: []*efs.AccessPointDescription{ + { + AccessPointId: aws.String(accessPointId), + FileSystemId: aws.String(fsId), + PosixUser: &efs.PosixUser{ + Gid: aws.Int64(Gid), + Uid: aws.Int64(Uid), + }, + }, + { + AccessPointId: aws.String(accessPointId), + FileSystemId: aws.String(fsId), + PosixUser: &efs.PosixUser{ + Gid: aws.Int64(1001), + Uid: aws.Int64(1001), + }, + }, + }, + NextToken: nil, + } + + ctx := context.Background() + mockEfs.EXPECT().DescribeAccessPointsWithContext(gomock.Eq(ctx), gomock.Any()).Return(output, nil) + res, err := c.ListAccessPoints(ctx, fsId) + if err != nil { + t.Fatalf("List Access Points failed: %v", err) + } + + if res == nil { + t.Fatal("Result is nil") + } + + if len(res) != 2 { + t.Fatalf("Expected two AccessPoints in response but got: %v", res) + } + + mockctl.Finish() + }, + }, + { + name: "Fail - Access Denied", + testFunc: func(t *testing.T) { + mockctl := gomock.NewController(t) + mockEfs := mocks.NewMockEfs(mockctl) + c := &cloud{efs: mockEfs} + ctx := context.Background() + mockEfs.EXPECT().DescribeAccessPointsWithContext(gomock.Eq(ctx), gomock.Any()).Return(nil, awserr.New(AccessDeniedException, "Access Denied", errors.New("Access Denied"))) + _, err := c.ListAccessPoints(ctx, fsId) + if err == nil { + t.Fatalf("List Access Points should have failed: %v", err) + } + + mockctl.Finish() + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} + func TestDescribeFileSystem(t *testing.T) { var ( fsId = "fs-abcd1234" diff --git a/pkg/cloud/fakes.go b/pkg/cloud/fakes.go index d05e1a38e..5cbb2af5d 100644 --- a/pkg/cloud/fakes.go +++ b/pkg/cloud/fakes.go @@ -97,3 +97,10 @@ func (c *FakeCloudProvider) DescribeMountTargets(ctx context.Context, fileSystem return nil, ErrNotFound } + +func (c *FakeCloudProvider) ListAccessPoints(ctx context.Context, fileSystemId string) ([]*AccessPoint, error) { + accessPoints := []*AccessPoint{ + c.accessPoints[fileSystemId], + } + return accessPoints, nil +} diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 78a324b37..be82c5445 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -99,13 +99,13 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) azName string basePath string err error - gid int + gid int64 gidMin int gidMax int localCloud cloud.Cloud provisioningMode string roleArn string - uid int + uid int64 ) //Parse parameters @@ -149,7 +149,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) uid = -1 if value, ok := volumeParams[Uid]; ok { - uid, err = strconv.Atoi(value) + uid, err = strconv.ParseInt(value, 10, 64) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "Failed to parse invalid %v: %v", Uid, err) } @@ -160,7 +160,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) gid = -1 if value, ok := volumeParams[Gid]; ok { - gid, err = strconv.Atoi(value) + gid, err = strconv.ParseInt(value, 10, 64) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "Failed to parse invalid %v: %v", Gid, err) } @@ -233,9 +233,9 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) return nil, status.Errorf(codes.Internal, "Failed to fetch File System info: %v", err) } - var allocatedGid int + var allocatedGid int64 if uid == -1 || gid == -1 { - allocatedGid, err = d.gidAllocator.getNextGid(accessPointsOptions.FileSystemId, gidMin, gidMax) + allocatedGid, err = d.gidAllocator.getNextGid(ctx, accessPointsOptions.FileSystemId, gidMin, gidMax) if err != nil { return nil, err } @@ -283,15 +283,12 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } klog.Infof("Using %v as the access point directory.", rootDir) - accessPointsOptions.Uid = int64(uid) - accessPointsOptions.Gid = int64(gid) + accessPointsOptions.Uid = uid + accessPointsOptions.Gid = gid accessPointsOptions.DirectoryPath = rootDir accessPointId, err := localCloud.CreateAccessPoint(ctx, volName, accessPointsOptions) if err != nil { - if allocatedGid != 0 { - d.gidAllocator.releaseGid(accessPointsOptions.FileSystemId, gid) - } if err == cloud.ErrAccessDenied { return nil, status.Errorf(codes.Unauthenticated, "Access Denied. Please ensure you have the right AWS permissions: %v", err) } diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 8cd57ee32..e407169bb 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "regexp" + "strconv" "testing" "github.com/google/uuid" @@ -48,7 +49,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -113,7 +114,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -171,6 +172,322 @@ func TestCreateVolume(t *testing.T) { mockCtl.Finish() }, }, + { + name: "Success: avoiding GID collision", + testFunc: func(t *testing.T) { + mockCtl := gomock.NewController(t) + mockCloud := mocks.NewMockCloud(mockCtl) + + driver := &Driver{ + endpoint: endpoint, + cloud: mockCloud, + gidAllocator: NewGidAllocator(mockCloud), + } + + req := &csi.CreateVolumeRequest{ + Name: volumeName, + VolumeCapabilities: []*csi.VolumeCapability{ + stdVolCap, + }, + CapacityRange: &csi.CapacityRange{ + RequiredBytes: capacityRange, + }, + Parameters: map[string]string{ + ProvisioningMode: "efs-ap", + FsId: fsId, + DirectoryPerms: "777", + BasePath: "test", + GidMin: "1000", + GidMax: "1003", + }, + } + + ctx := context.Background() + fileSystem := &cloud.FileSystem{ + FileSystemId: fsId, + } + accessPoint := &cloud.AccessPoint{ + AccessPointId: apId, + FileSystemId: fsId, + } + accessPoints := []*cloud.AccessPoint{ + { + AccessPointId: apId, + FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1003, + Uid: 1003, + }, + }, + { + AccessPointId: apId, + FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1002, + Uid: 1002, + }, + }, + } + + var expectedGid int64 = 1001 //1003 and 1002 are taken, next available is 1001 + mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) + mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil). + Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { + if accessPointOpts.Uid != expectedGid { + t.Fatalf("Uid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Uid) + } + if accessPointOpts.Gid != expectedGid { + t.Fatalf("Gid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Gid) + } + }) + + res, err := driver.CreateVolume(ctx, req) + + if err != nil { + t.Fatalf("CreateVolume failed: %v", err) + } + + if res.Volume == nil { + t.Fatal("Volume is nil") + } + + if res.Volume.VolumeId != volumeId { + t.Fatalf("Volume Id mismatched. Expected: %v, Actual: %v", volumeId, res.Volume.VolumeId) + } + mockCtl.Finish() + }, + }, + { + name: "Success: reuse released GID", + testFunc: func(t *testing.T) { + mockCtl := gomock.NewController(t) + mockCloud := mocks.NewMockCloud(mockCtl) + + driver := &Driver{ + endpoint: endpoint, + cloud: mockCloud, + gidAllocator: NewGidAllocator(mockCloud), + } + + req := &csi.CreateVolumeRequest{ + Name: volumeName, + VolumeCapabilities: []*csi.VolumeCapability{ + stdVolCap, + }, + CapacityRange: &csi.CapacityRange{ + RequiredBytes: capacityRange, + }, + Parameters: map[string]string{ + ProvisioningMode: "efs-ap", + FsId: fsId, + DirectoryPerms: "777", + BasePath: "test", + GidMin: "1000", + GidMax: "1004", + }, + } + + ctx := context.Background() + fileSystem := &cloud.FileSystem{ + FileSystemId: fsId, + } + ap1 := &cloud.AccessPoint{ + AccessPointId: apId, + FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1001, + Uid: 1001, + }, + } + ap2 := &cloud.AccessPoint{ + AccessPointId: apId, + FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1002, + Uid: 1002, + }, + } + ap3 := &cloud.AccessPoint{ + AccessPointId: apId, + FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1003, + Uid: 1003, + }, + } + ap4 := &cloud.AccessPoint{ + AccessPointId: apId, + FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1004, + Uid: 1004, + }, + } + + // Let allocator jump over some GIDS. + accessPoints := []*cloud.AccessPoint{ap3, ap4} + var expectedGid int64 = 1002 // 1003 and 1004 is taken. + + mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) + mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(ap2, nil). + Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { + if accessPointOpts.Uid != expectedGid { + t.Fatalf("Uid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Uid) + } + if accessPointOpts.Gid != expectedGid { + t.Fatalf("Gid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Gid) + } + }) + + res, err := driver.CreateVolume(ctx, req) + + // 2. Simulate access point removal and verify their GIDs returned to allocator. + accessPoints = []*cloud.AccessPoint{} + expectedGid = 1004 // 1003 and 1004 are now free, if no GID return would happen allocator would pick 1001. + + mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) + mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(ap3, nil). + Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { + if accessPointOpts.Uid != expectedGid { + t.Fatalf("Uid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Uid) + } + if accessPointOpts.Gid != expectedGid { + t.Fatalf("Gid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Gid) + } + }) + + res, err = driver.CreateVolume(ctx, req) + //// + accessPoints = []*cloud.AccessPoint{ap1, ap4} + + expectedGid = 1003 + mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) + mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(ap2, nil). + Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { + if accessPointOpts.Uid != expectedGid { + t.Fatalf("Uid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Uid) + } + if accessPointOpts.Gid != expectedGid { + t.Fatalf("Gid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Gid) + } + }) + + res, err = driver.CreateVolume(ctx, req) + + if err != nil { + t.Fatalf("CreateVolume failed: %v", err) + } + + if res.Volume == nil { + t.Fatal("Volume is nil") + } + + if res.Volume.VolumeId != volumeId { + t.Fatalf("Volume Id mismatched. Expected: %v, Actual: %v", volumeId, res.Volume.VolumeId) + } + mockCtl.Finish() + }, + }, + { + name: "Success: EFS access point limit", + testFunc: func(t *testing.T) { + mockCtl := gomock.NewController(t) + mockCloud := mocks.NewMockCloud(mockCtl) + + driver := &Driver{ + endpoint: endpoint, + cloud: mockCloud, + gidAllocator: NewGidAllocator(mockCloud), + } + + req := &csi.CreateVolumeRequest{ + Name: volumeName, + VolumeCapabilities: []*csi.VolumeCapability{ + stdVolCap, + }, + CapacityRange: &csi.CapacityRange{ + RequiredBytes: capacityRange, + }, + Parameters: map[string]string{ + ProvisioningMode: "efs-ap", + FsId: fsId, + DirectoryPerms: "777", + BasePath: "test", + GidMin: "1000", + GidMax: "1200", + }, + } + + ctx := context.Background() + fileSystem := &cloud.FileSystem{ + FileSystemId: fsId, + } + + accessPoints := []*cloud.AccessPoint{} + for i := 0; i < 119; i++ { + gidMax, err := strconv.Atoi(req.Parameters[GidMax]) + if err != nil { + t.Fatalf("Failed to convert GidMax Parameter to int.") + } + userGid := gidMax - i + ap := &cloud.AccessPoint{ + AccessPointId: apId, + FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: int64(userGid), + Uid: int64(userGid), + }, + } + accessPoints = append(accessPoints, ap) + } + + lastAccessPoint := &cloud.AccessPoint{ + AccessPointId: apId, + FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1081, + Uid: 1081, + }, + } + + expectedGid := 1081 + mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) + mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(lastAccessPoint, nil). + Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { + if accessPointOpts.Uid != int64(expectedGid) { + t.Fatalf("Uid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Uid) + } + if accessPointOpts.Gid != int64(expectedGid) { + t.Fatalf("Gid mismatched. Expected: %v, actual: %v", expectedGid, accessPointOpts.Gid) + } + }) + + var err error + + // Allocate last available GID + _, err = driver.CreateVolume(ctx, req) + if err != nil { + t.Fatalf("CreateVolume failed.") + } + + accessPoints = append(accessPoints, lastAccessPoint) + mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) + mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(lastAccessPoint, nil).AnyTimes() + + // All 120 GIDs are taken now, internal limit should take effect causing CreateVolume to fail. + _, err = driver.CreateVolume(ctx, req) + if err == nil { + t.Fatalf("CreateVolume should have failed.") + } + mockCtl.Finish() + }, + }, { name: "Success: Normal flow", testFunc: func(t *testing.T) { @@ -180,7 +497,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -209,8 +526,14 @@ func TestCreateVolume(t *testing.T) { accessPoint := &cloud.AccessPoint{ AccessPointId: apId, FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1000, + Uid: 1000, + }, } + accessPoints := []*cloud.AccessPoint{accessPoint} mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil) res, err := driver.CreateVolume(ctx, req) @@ -238,7 +561,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -264,8 +587,13 @@ func TestCreateVolume(t *testing.T) { accessPoint := &cloud.AccessPoint{ AccessPointId: apId, FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: DefaultGidMin - 1, //use GID that is not in default range + }, } + accessPoints := []*cloud.AccessPoint{accessPoint} mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil) res, err := driver.CreateVolume(ctx, req) @@ -293,7 +621,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr("cluster:efs"), } @@ -321,8 +649,14 @@ func TestCreateVolume(t *testing.T) { accessPoint := &cloud.AccessPoint{ AccessPointId: apId, FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1000, + Uid: 1000, + }, } + accessPoints := []*cloud.AccessPoint{accessPoint} mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil) res, err := driver.CreateVolume(ctx, req) @@ -350,7 +684,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr("cluster-efs"), } @@ -378,8 +712,14 @@ func TestCreateVolume(t *testing.T) { accessPoint := &cloud.AccessPoint{ AccessPointId: apId, FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1000, + Uid: 1000, + }, } + accessPoints := []*cloud.AccessPoint{accessPoint} mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(accessPoints, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil) res, err := driver.CreateVolume(ctx, req) @@ -407,7 +747,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -444,6 +784,7 @@ func TestCreateVolume(t *testing.T) { FileSystemId: fsId, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil). Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { @@ -480,7 +821,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -515,6 +856,7 @@ func TestCreateVolume(t *testing.T) { FileSystemId: fsId, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil). Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { @@ -551,7 +893,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -589,6 +931,7 @@ func TestCreateVolume(t *testing.T) { FileSystemId: fsId, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil). Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { @@ -625,7 +968,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -663,6 +1006,7 @@ func TestCreateVolume(t *testing.T) { FileSystemId: fsId, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil). Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { @@ -700,7 +1044,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -736,6 +1080,7 @@ func TestCreateVolume(t *testing.T) { FileSystemId: fsId, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil). Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { @@ -772,7 +1117,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -804,6 +1149,7 @@ func TestCreateVolume(t *testing.T) { FileSystemId: fsId, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil). Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { @@ -840,7 +1186,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -873,6 +1219,7 @@ func TestCreateVolume(t *testing.T) { FileSystemId: fsId, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil). Do(func(ctx context.Context, volumeName string, accessPointOpts *cloud.AccessPointOptions) { @@ -909,7 +1256,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -937,7 +1284,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -966,7 +1313,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -998,7 +1345,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1040,7 +1387,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -1088,7 +1435,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1130,7 +1477,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1165,7 +1512,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1199,7 +1546,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1233,7 +1580,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1268,7 +1615,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1304,7 +1651,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1340,7 +1687,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1376,7 +1723,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1412,7 +1759,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1448,7 +1795,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1484,7 +1831,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1521,7 +1868,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1558,7 +1905,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1594,7 +1941,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1630,7 +1977,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1668,7 +2015,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1706,7 +2053,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1744,7 +2091,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1769,6 +2116,7 @@ func TestCreateVolume(t *testing.T) { FileSystemId: fsId, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return([]*cloud.AccessPoint{}, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(nil, errors.New("CreateAccessPoint call failed")) _, err := driver.CreateVolume(ctx, req) if err == nil { @@ -1786,7 +2134,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1811,6 +2159,7 @@ func TestCreateVolume(t *testing.T) { FileSystemId: fsId, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return([]*cloud.AccessPoint{}, nil) mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(nil, cloud.ErrAccessDenied) _, err := driver.CreateVolume(ctx, req) if err == nil { @@ -1828,7 +2177,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1852,18 +2201,29 @@ func TestCreateVolume(t *testing.T) { fileSystem := &cloud.FileSystem{ FileSystemId: fsId, } - accessPoint := &cloud.AccessPoint{ + ap1 := &cloud.AccessPoint{ AccessPointId: apId, FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1000, + Uid: 1000, + }, + } + ap2 := &cloud.AccessPoint{ + AccessPointId: apId, + FileSystemId: fsId, + PosixUser: &cloud.PosixUser{ + Gid: 1001, + Uid: 1001, + }, } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil).AnyTimes() - mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(accessPoint, nil).AnyTimes() + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return([]*cloud.AccessPoint{ap1, ap2}, nil).AnyTimes() + mockCloud.EXPECT().CreateAccessPoint(gomock.Eq(ctx), gomock.Any(), gomock.Any()).Return(ap2, nil).AnyTimes() var err error - // Input grants 2 GIDS, third CreateVolume call should result in error - for i := 0; i < 3; i++ { - _, err = driver.CreateVolume(ctx, req) - } + // All GIDs from available range are taken, CreateVolume should fail. + _, err = driver.CreateVolume(ctx, req) if err == nil { t.Fatalf("CreateVolume did not fail") @@ -1880,7 +2240,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } @@ -1928,7 +2288,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -1954,6 +2314,7 @@ func TestCreateVolume(t *testing.T) { } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) _, err := driver.CreateVolume(ctx, req) if err == nil { @@ -1976,7 +2337,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -2002,6 +2363,7 @@ func TestCreateVolume(t *testing.T) { } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) _, err := driver.CreateVolume(ctx, req) if err == nil { @@ -2024,7 +2386,7 @@ func TestCreateVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.CreateVolumeRequest{ @@ -2050,6 +2412,7 @@ func TestCreateVolume(t *testing.T) { } mockCloud.EXPECT().DescribeFileSystem(gomock.Eq(ctx), gomock.Any()).Return(fileSystem, nil) + mockCloud.EXPECT().ListAccessPoints(gomock.Eq(ctx), gomock.Any()).Return(nil, nil) _, err := driver.CreateVolume(ctx, req) if err == nil { @@ -2089,7 +2452,7 @@ func TestDeleteVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.DeleteVolumeRequest{ @@ -2116,7 +2479,7 @@ func TestDeleteVolume(t *testing.T) { endpoint: endpoint, cloud: mockCloud, mounter: mockMounter, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), deleteAccessPointRootDir: true, } @@ -2155,7 +2518,7 @@ func TestDeleteVolume(t *testing.T) { endpoint: endpoint, cloud: mockCloud, mounter: mockMounter, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), deleteAccessPointRootDir: true, } @@ -2183,7 +2546,7 @@ func TestDeleteVolume(t *testing.T) { endpoint: endpoint, cloud: mockCloud, mounter: mockMounter, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), deleteAccessPointRootDir: true, } @@ -2211,7 +2574,7 @@ func TestDeleteVolume(t *testing.T) { endpoint: endpoint, cloud: mockCloud, mounter: mockMounter, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), deleteAccessPointRootDir: true, } @@ -2239,7 +2602,7 @@ func TestDeleteVolume(t *testing.T) { endpoint: endpoint, cloud: mockCloud, mounter: mockMounter, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), deleteAccessPointRootDir: true, } @@ -2275,7 +2638,7 @@ func TestDeleteVolume(t *testing.T) { endpoint: endpoint, cloud: mockCloud, mounter: mockMounter, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), deleteAccessPointRootDir: true, } @@ -2312,7 +2675,7 @@ func TestDeleteVolume(t *testing.T) { endpoint: endpoint, cloud: mockCloud, mounter: mockMounter, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), deleteAccessPointRootDir: true, } @@ -2348,7 +2711,7 @@ func TestDeleteVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.DeleteVolumeRequest{ @@ -2373,7 +2736,7 @@ func TestDeleteVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.DeleteVolumeRequest{ @@ -2398,7 +2761,7 @@ func TestDeleteVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.DeleteVolumeRequest{ @@ -2423,7 +2786,7 @@ func TestDeleteVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } req := &csi.DeleteVolumeRequest{ @@ -2447,7 +2810,7 @@ func TestDeleteVolume(t *testing.T) { driver := &Driver{ endpoint: endpoint, cloud: mockCloud, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), tags: parseTagsFromStr(""), } diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index c7a974b7d..0fa414c14 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -69,7 +69,7 @@ func NewDriver(endpoint, efsUtilsCfgPath, efsUtilsStaticFilesPath, tags string, volMetricsOptIn: volMetricsOptIn, volMetricsRefreshPeriod: volMetricsRefreshPeriod, volMetricsFsRateLimit: volMetricsFsRateLimit, - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(cloud), deleteAccessPointRootDir: deleteAccessPointRootDir, tags: parseTagsFromStr(strings.TrimSpace(tags)), } diff --git a/pkg/driver/gid_allocator.go b/pkg/driver/gid_allocator.go index f672cdf08..f7d5063e7 100644 --- a/pkg/driver/gid_allocator.go +++ b/pkg/driver/gid_allocator.go @@ -1,84 +1,57 @@ package driver import ( - "container/heap" - "sync" - + "context" + "fmt" + "github.com/kubernetes-sigs/aws-efs-csi-driver/pkg/cloud" + "golang.org/x/exp/slices" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "k8s.io/klog/v2" + "sync" ) -type IntHeap []int - -func (h IntHeap) Len() int { - return len(h) -} -func (h IntHeap) Less(i, j int) bool { - return h[i] < h[j] -} -func (h IntHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} +var ACCESS_POINT_PER_FS_LIMIT int = 120 -func (h *IntHeap) Push(x interface{}) { - *h = append(*h, x.(int)) -} - -func (h *IntHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - *h = old[0 : n-1] - return x +type FilesystemID struct { + gidMin int + gidMax int } type GidAllocator struct { - fsIdGidMap map[string]*IntHeap + cloud cloud.Cloud + fsIdGidMap map[string]*FilesystemID mu sync.Mutex } -func NewGidAllocator() GidAllocator { +func NewGidAllocator(cloud cloud.Cloud) GidAllocator { return GidAllocator{ - fsIdGidMap: make(map[string]*IntHeap), + cloud: cloud, + fsIdGidMap: make(map[string]*FilesystemID), } } // Retrieves the next available GID -func (g *GidAllocator) getNextGid(fsId string, gidMin, gidMax int) (int, error) { +func (g *GidAllocator) getNextGid(ctx context.Context, fsId string, gidMin, gidMax int) (int64, error) { g.mu.Lock() defer g.mu.Unlock() klog.V(5).Infof("Recieved getNextGid for fsId: %v, min: %v, max: %v", fsId, gidMin, gidMax) - if _, ok := g.fsIdGidMap[fsId]; !ok { - klog.V(5).Infof("FS Id doesn't exist, initializing...") - g.initFsId(fsId, gidMin, gidMax) + usedGids, err := g.getUsedGids(ctx, fsId) + if err != nil { + return 0, status.Errorf(codes.Internal, "Failed to discover used GIDs for filesystem: %v: %v ", fsId, err) } - gidHeap := g.fsIdGidMap[fsId] + gid, err := getNextUnusedGid(usedGids, gidMin, gidMax) - if gidHeap.Len() > 0 { - return heap.Pop(gidHeap).(int), nil - } else { - return 0, status.Errorf(codes.Internal, "Failed to locate a free GID for given the file system: %v. "+ + if err != nil { + return 0, status.Errorf(codes.Internal, "Failed to locate a free GID for given file system: %v. "+ "Please create a new storage class with a new file-system", fsId) } -} -func (g *GidAllocator) releaseGid(fsId string, gid int) { - g.mu.Lock() - defer g.mu.Unlock() + return int64(gid), nil - gidHeap := g.fsIdGidMap[fsId] - gidHeap.Push(gid) -} - -// Creates an entry fsIdGidMap if fsId does not exist. -func (g *GidAllocator) initFsId(fsId string, gidMin, gidMax int) { - h := initHeap(gidMin, gidMax) - heap.Init(h) - g.fsIdGidMap[fsId] = h } func (g *GidAllocator) removeFsId(fsId string) { @@ -87,13 +60,58 @@ func (g *GidAllocator) removeFsId(fsId string) { delete(g.fsIdGidMap, fsId) } -// Initializes a heap inclusive of min & max -func initHeap(min, max int) *IntHeap { - h := make(IntHeap, max-min+1) - val := min - for i := range h { - h[i] = val - val += 1 +func (g *GidAllocator) getUsedGids(ctx context.Context, fsId string) (gids []int64, err error) { + gids = []int64{} + accessPoints, err := g.cloud.ListAccessPoints(ctx, fsId) + if err != nil { + err = fmt.Errorf("failed to list access points: %v", err) + return + } + if len(accessPoints) == 0 { + return gids, nil + } + for _, ap := range accessPoints { + // This should happen only in tests - skip nil pointers. + if ap == nil { + continue + } + if ap != nil && ap.PosixUser == nil { + err = fmt.Errorf("failed to discover used GID because PosixUser is nil for AccessPoint: %s", ap.AccessPointId) + return + } + gids = append(gids, ap.PosixUser.Gid) + } + klog.V(5).Infof("Discovered used GIDs: %+v for FS ID: %v", gids, fsId) + return +} + +func getNextUnusedGid(usedGids []int64, gidMin, gidMax int) (nextGid int, err error) { + requestedRange := gidMax - gidMin + + if requestedRange > ACCESS_POINT_PER_FS_LIMIT { + klog.Warningf("Requested GID range (%v:%v) exceeds EFS Access Point limit (%v) per Filesystem. Driver will not allocate GIDs outside of this limit.", gidMin, gidMax, ACCESS_POINT_PER_FS_LIMIT) + gidMin = gidMax - ACCESS_POINT_PER_FS_LIMIT + } + + var lookup func(usedGids []int64) + lookup = func(usedGids []int64) { + for gid := gidMax; gid > gidMin; gid-- { + if !slices.Contains(usedGids, int64(gid)) { + nextGid = gid + return + } + klog.V(5).Infof("Allocator found GID which is already in use: %v - trying next one.", nextGid) + } + return } - return &h + + nextGid = -1 + lookup(usedGids) + if nextGid == -1 { + err = fmt.Errorf("allocator failed to find available GID") + return + } + + klog.V(5).Infof("Allocator found unused GID: %v", nextGid) + return } diff --git a/pkg/driver/mocks/mock_cloud.go b/pkg/driver/mocks/mock_cloud.go index 85a40d1ce..557e5b9a5 100644 --- a/pkg/driver/mocks/mock_cloud.go +++ b/pkg/driver/mocks/mock_cloud.go @@ -122,3 +122,18 @@ func (mr *MockCloudMockRecorder) GetMetadata() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMetadata", reflect.TypeOf((*MockCloud)(nil).GetMetadata)) } + +// ListAccessPoints mocks base method. +func (m *MockCloud) ListAccessPoints(arg0 context.Context, arg1 string) ([]*cloud.AccessPoint, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAccessPoints", arg0, arg1) + ret0, _ := ret[0].([]*cloud.AccessPoint) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAccessPoints indicates an expected call of ListAccessPoints. +func (mr *MockCloudMockRecorder) ListAccessPoints(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAccessPoints", reflect.TypeOf((*MockCloud)(nil).ListAccessPoints), arg0, arg1) +} diff --git a/pkg/driver/sanity_test.go b/pkg/driver/sanity_test.go index d1482e63d..c2c55bf65 100644 --- a/pkg/driver/sanity_test.go +++ b/pkg/driver/sanity_test.go @@ -69,16 +69,17 @@ func TestSanityEFSCSI(t *testing.T) { nodeCaps := SetNodeCapOptInFeatures(true) mockCtrl := gomock.NewController(t) + mockCloud := cloud.NewFakeCloudProvider() drv := Driver{ endpoint: endpoint, nodeID: "sanity", mounter: NewFakeMounter(), efsWatchdog: &mockWatchdog{}, - cloud: cloud.NewFakeCloudProvider(), + cloud: mockCloud, nodeCaps: nodeCaps, volMetricsOptIn: true, volStatter: NewVolStatter(), - gidAllocator: NewGidAllocator(), + gidAllocator: NewGidAllocator(mockCloud), } defer func() { if r := recover(); r != nil { diff --git a/vendor/golang.org/x/exp/LICENSE b/vendor/golang.org/x/exp/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/vendor/golang.org/x/exp/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/exp/PATENTS b/vendor/golang.org/x/exp/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/golang.org/x/exp/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/exp/constraints/constraints.go b/vendor/golang.org/x/exp/constraints/constraints.go new file mode 100644 index 000000000..2c033dff4 --- /dev/null +++ b/vendor/golang.org/x/exp/constraints/constraints.go @@ -0,0 +1,50 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package constraints defines a set of useful constraints to be used +// with type parameters. +package constraints + +// Signed is a constraint that permits any signed integer type. +// If future releases of Go add new predeclared signed integer types, +// this constraint will be modified to include them. +type Signed interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +// Unsigned is a constraint that permits any unsigned integer type. +// If future releases of Go add new predeclared unsigned integer types, +// this constraint will be modified to include them. +type Unsigned interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +// Integer is a constraint that permits any integer type. +// If future releases of Go add new predeclared integer types, +// this constraint will be modified to include them. +type Integer interface { + Signed | Unsigned +} + +// Float is a constraint that permits any floating-point type. +// If future releases of Go add new predeclared floating-point types, +// this constraint will be modified to include them. +type Float interface { + ~float32 | ~float64 +} + +// Complex is a constraint that permits any complex numeric type. +// If future releases of Go add new predeclared complex numeric types, +// this constraint will be modified to include them. +type Complex interface { + ~complex64 | ~complex128 +} + +// Ordered is a constraint that permits any ordered type: any type +// that supports the operators < <= >= >. +// If future releases of Go add new ordered types, +// this constraint will be modified to include them. +type Ordered interface { + Integer | Float | ~string +} diff --git a/vendor/golang.org/x/exp/slices/cmp.go b/vendor/golang.org/x/exp/slices/cmp.go new file mode 100644 index 000000000..fbf1934a0 --- /dev/null +++ b/vendor/golang.org/x/exp/slices/cmp.go @@ -0,0 +1,44 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package slices + +import "golang.org/x/exp/constraints" + +// min is a version of the predeclared function from the Go 1.21 release. +func min[T constraints.Ordered](a, b T) T { + if a < b || isNaN(a) { + return a + } + return b +} + +// max is a version of the predeclared function from the Go 1.21 release. +func max[T constraints.Ordered](a, b T) T { + if a > b || isNaN(a) { + return a + } + return b +} + +// cmpLess is a copy of cmp.Less from the Go 1.21 release. +func cmpLess[T constraints.Ordered](x, y T) bool { + return (isNaN(x) && !isNaN(y)) || x < y +} + +// cmpCompare is a copy of cmp.Compare from the Go 1.21 release. +func cmpCompare[T constraints.Ordered](x, y T) int { + xNaN := isNaN(x) + yNaN := isNaN(y) + if xNaN && yNaN { + return 0 + } + if xNaN || x < y { + return -1 + } + if yNaN || x > y { + return +1 + } + return 0 +} diff --git a/vendor/golang.org/x/exp/slices/slices.go b/vendor/golang.org/x/exp/slices/slices.go new file mode 100644 index 000000000..5e8158bba --- /dev/null +++ b/vendor/golang.org/x/exp/slices/slices.go @@ -0,0 +1,499 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package slices defines various functions useful with slices of any type. +package slices + +import ( + "unsafe" + + "golang.org/x/exp/constraints" +) + +// Equal reports whether two slices are equal: the same length and all +// elements equal. If the lengths are different, Equal returns false. +// Otherwise, the elements are compared in increasing index order, and the +// comparison stops at the first unequal pair. +// Floating point NaNs are not considered equal. +func Equal[S ~[]E, E comparable](s1, s2 S) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + return true +} + +// EqualFunc reports whether two slices are equal using an equality +// function on each pair of elements. If the lengths are different, +// EqualFunc returns false. Otherwise, the elements are compared in +// increasing index order, and the comparison stops at the first index +// for which eq returns false. +func EqualFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, eq func(E1, E2) bool) bool { + if len(s1) != len(s2) { + return false + } + for i, v1 := range s1 { + v2 := s2[i] + if !eq(v1, v2) { + return false + } + } + return true +} + +// Compare compares the elements of s1 and s2, using [cmp.Compare] on each pair +// of elements. The elements are compared sequentially, starting at index 0, +// until one element is not equal to the other. +// The result of comparing the first non-matching elements is returned. +// If both slices are equal until one of them ends, the shorter slice is +// considered less than the longer one. +// The result is 0 if s1 == s2, -1 if s1 < s2, and +1 if s1 > s2. +func Compare[S ~[]E, E constraints.Ordered](s1, s2 S) int { + for i, v1 := range s1 { + if i >= len(s2) { + return +1 + } + v2 := s2[i] + if c := cmpCompare(v1, v2); c != 0 { + return c + } + } + if len(s1) < len(s2) { + return -1 + } + return 0 +} + +// CompareFunc is like [Compare] but uses a custom comparison function on each +// pair of elements. +// The result is the first non-zero result of cmp; if cmp always +// returns 0 the result is 0 if len(s1) == len(s2), -1 if len(s1) < len(s2), +// and +1 if len(s1) > len(s2). +func CompareFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, cmp func(E1, E2) int) int { + for i, v1 := range s1 { + if i >= len(s2) { + return +1 + } + v2 := s2[i] + if c := cmp(v1, v2); c != 0 { + return c + } + } + if len(s1) < len(s2) { + return -1 + } + return 0 +} + +// Index returns the index of the first occurrence of v in s, +// or -1 if not present. +func Index[S ~[]E, E comparable](s S, v E) int { + for i := range s { + if v == s[i] { + return i + } + } + return -1 +} + +// IndexFunc returns the first index i satisfying f(s[i]), +// or -1 if none do. +func IndexFunc[S ~[]E, E any](s S, f func(E) bool) int { + for i := range s { + if f(s[i]) { + return i + } + } + return -1 +} + +// Contains reports whether v is present in s. +func Contains[S ~[]E, E comparable](s S, v E) bool { + return Index(s, v) >= 0 +} + +// ContainsFunc reports whether at least one +// element e of s satisfies f(e). +func ContainsFunc[S ~[]E, E any](s S, f func(E) bool) bool { + return IndexFunc(s, f) >= 0 +} + +// Insert inserts the values v... into s at index i, +// returning the modified slice. +// The elements at s[i:] are shifted up to make room. +// In the returned slice r, r[i] == v[0], +// and r[i+len(v)] == value originally at r[i]. +// Insert panics if i is out of range. +// This function is O(len(s) + len(v)). +func Insert[S ~[]E, E any](s S, i int, v ...E) S { + m := len(v) + if m == 0 { + return s + } + n := len(s) + if i == n { + return append(s, v...) + } + if n+m > cap(s) { + // Use append rather than make so that we bump the size of + // the slice up to the next storage class. + // This is what Grow does but we don't call Grow because + // that might copy the values twice. + s2 := append(s[:i], make(S, n+m-i)...) + copy(s2[i:], v) + copy(s2[i+m:], s[i:]) + return s2 + } + s = s[:n+m] + + // before: + // s: aaaaaaaabbbbccccccccdddd + // ^ ^ ^ ^ + // i i+m n n+m + // after: + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // + // a are the values that don't move in s. + // v are the values copied in from v. + // b and c are the values from s that are shifted up in index. + // d are the values that get overwritten, never to be seen again. + + if !overlaps(v, s[i+m:]) { + // Easy case - v does not overlap either the c or d regions. + // (It might be in some of a or b, or elsewhere entirely.) + // The data we copy up doesn't write to v at all, so just do it. + + copy(s[i+m:], s[i:]) + + // Now we have + // s: aaaaaaaabbbbbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // Note the b values are duplicated. + + copy(s[i:], v) + + // Now we have + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // That's the result we want. + return s + } + + // The hard case - v overlaps c or d. We can't just shift up + // the data because we'd move or clobber the values we're trying + // to insert. + // So instead, write v on top of d, then rotate. + copy(s[n:], v) + + // Now we have + // s: aaaaaaaabbbbccccccccvvvv + // ^ ^ ^ ^ + // i i+m n n+m + + rotateRight(s[i:], m) + + // Now we have + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // That's the result we want. + return s +} + +// Delete removes the elements s[i:j] from s, returning the modified slice. +// Delete panics if s[i:j] is not a valid slice of s. +// Delete is O(len(s)-j), so if many items must be deleted, it is better to +// make a single call deleting them all together than to delete one at a time. +// Delete might not modify the elements s[len(s)-(j-i):len(s)]. If those +// elements contain pointers you might consider zeroing those elements so that +// objects they reference can be garbage collected. +func Delete[S ~[]E, E any](s S, i, j int) S { + _ = s[i:j] // bounds check + + return append(s[:i], s[j:]...) +} + +// DeleteFunc removes any elements from s for which del returns true, +// returning the modified slice. +// When DeleteFunc removes m elements, it might not modify the elements +// s[len(s)-m:len(s)]. If those elements contain pointers you might consider +// zeroing those elements so that objects they reference can be garbage +// collected. +func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { + i := IndexFunc(s, del) + if i == -1 { + return s + } + // Don't start copying elements until we find one to delete. + for j := i + 1; j < len(s); j++ { + if v := s[j]; !del(v) { + s[i] = v + i++ + } + } + return s[:i] +} + +// Replace replaces the elements s[i:j] by the given v, and returns the +// modified slice. Replace panics if s[i:j] is not a valid slice of s. +func Replace[S ~[]E, E any](s S, i, j int, v ...E) S { + _ = s[i:j] // verify that i:j is a valid subslice + + if i == j { + return Insert(s, i, v...) + } + if j == len(s) { + return append(s[:i], v...) + } + + tot := len(s[:i]) + len(v) + len(s[j:]) + if tot > cap(s) { + // Too big to fit, allocate and copy over. + s2 := append(s[:i], make(S, tot-i)...) // See Insert + copy(s2[i:], v) + copy(s2[i+len(v):], s[j:]) + return s2 + } + + r := s[:tot] + + if i+len(v) <= j { + // Easy, as v fits in the deleted portion. + copy(r[i:], v) + if i+len(v) != j { + copy(r[i+len(v):], s[j:]) + } + return r + } + + // We are expanding (v is bigger than j-i). + // The situation is something like this: + // (example has i=4,j=8,len(s)=16,len(v)=6) + // s: aaaaxxxxbbbbbbbbyy + // ^ ^ ^ ^ + // i j len(s) tot + // a: prefix of s + // x: deleted range + // b: more of s + // y: area to expand into + + if !overlaps(r[i+len(v):], v) { + // Easy, as v is not clobbered by the first copy. + copy(r[i+len(v):], s[j:]) + copy(r[i:], v) + return r + } + + // This is a situation where we don't have a single place to which + // we can copy v. Parts of it need to go to two different places. + // We want to copy the prefix of v into y and the suffix into x, then + // rotate |y| spots to the right. + // + // v[2:] v[:2] + // | | + // s: aaaavvvvbbbbbbbbvv + // ^ ^ ^ ^ + // i j len(s) tot + // + // If either of those two destinations don't alias v, then we're good. + y := len(v) - (j - i) // length of y portion + + if !overlaps(r[i:j], v) { + copy(r[i:j], v[y:]) + copy(r[len(s):], v[:y]) + rotateRight(r[i:], y) + return r + } + if !overlaps(r[len(s):], v) { + copy(r[len(s):], v[:y]) + copy(r[i:j], v[y:]) + rotateRight(r[i:], y) + return r + } + + // Now we know that v overlaps both x and y. + // That means that the entirety of b is *inside* v. + // So we don't need to preserve b at all; instead we + // can copy v first, then copy the b part of v out of + // v to the right destination. + k := startIdx(v, s[j:]) + copy(r[i:], v) + copy(r[i+len(v):], r[i+k:]) + return r +} + +// Clone returns a copy of the slice. +// The elements are copied using assignment, so this is a shallow clone. +func Clone[S ~[]E, E any](s S) S { + // Preserve nil in case it matters. + if s == nil { + return nil + } + return append(S([]E{}), s...) +} + +// Compact replaces consecutive runs of equal elements with a single copy. +// This is like the uniq command found on Unix. +// Compact modifies the contents of the slice s and returns the modified slice, +// which may have a smaller length. +// When Compact discards m elements in total, it might not modify the elements +// s[len(s)-m:len(s)]. If those elements contain pointers you might consider +// zeroing those elements so that objects they reference can be garbage collected. +func Compact[S ~[]E, E comparable](s S) S { + if len(s) < 2 { + return s + } + i := 1 + for k := 1; k < len(s); k++ { + if s[k] != s[k-1] { + if i != k { + s[i] = s[k] + } + i++ + } + } + return s[:i] +} + +// CompactFunc is like [Compact] but uses an equality function to compare elements. +// For runs of elements that compare equal, CompactFunc keeps the first one. +func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S { + if len(s) < 2 { + return s + } + i := 1 + for k := 1; k < len(s); k++ { + if !eq(s[k], s[k-1]) { + if i != k { + s[i] = s[k] + } + i++ + } + } + return s[:i] +} + +// Grow increases the slice's capacity, if necessary, to guarantee space for +// another n elements. After Grow(n), at least n elements can be appended +// to the slice without another allocation. If n is negative or too large to +// allocate the memory, Grow panics. +func Grow[S ~[]E, E any](s S, n int) S { + if n < 0 { + panic("cannot be negative") + } + if n -= cap(s) - len(s); n > 0 { + // TODO(https://go.dev/issue/53888): Make using []E instead of S + // to workaround a compiler bug where the runtime.growslice optimization + // does not take effect. Revert when the compiler is fixed. + s = append([]E(s)[:cap(s)], make([]E, n)...)[:len(s)] + } + return s +} + +// Clip removes unused capacity from the slice, returning s[:len(s):len(s)]. +func Clip[S ~[]E, E any](s S) S { + return s[:len(s):len(s)] +} + +// Rotation algorithm explanation: +// +// rotate left by 2 +// start with +// 0123456789 +// split up like this +// 01 234567 89 +// swap first 2 and last 2 +// 89 234567 01 +// join first parts +// 89234567 01 +// recursively rotate first left part by 2 +// 23456789 01 +// join at the end +// 2345678901 +// +// rotate left by 8 +// start with +// 0123456789 +// split up like this +// 01 234567 89 +// swap first 2 and last 2 +// 89 234567 01 +// join last parts +// 89 23456701 +// recursively rotate second part left by 6 +// 89 01234567 +// join at the end +// 8901234567 + +// TODO: There are other rotate algorithms. +// This algorithm has the desirable property that it moves each element exactly twice. +// The triple-reverse algorithm is simpler and more cache friendly, but takes more writes. +// The follow-cycles algorithm can be 1-write but it is not very cache friendly. + +// rotateLeft rotates b left by n spaces. +// s_final[i] = s_orig[i+r], wrapping around. +func rotateLeft[E any](s []E, r int) { + for r != 0 && r != len(s) { + if r*2 <= len(s) { + swap(s[:r], s[len(s)-r:]) + s = s[:len(s)-r] + } else { + swap(s[:len(s)-r], s[r:]) + s, r = s[len(s)-r:], r*2-len(s) + } + } +} +func rotateRight[E any](s []E, r int) { + rotateLeft(s, len(s)-r) +} + +// swap swaps the contents of x and y. x and y must be equal length and disjoint. +func swap[E any](x, y []E) { + for i := 0; i < len(x); i++ { + x[i], y[i] = y[i], x[i] + } +} + +// overlaps reports whether the memory ranges a[0:len(a)] and b[0:len(b)] overlap. +func overlaps[E any](a, b []E) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + elemSize := unsafe.Sizeof(a[0]) + if elemSize == 0 { + return false + } + // TODO: use a runtime/unsafe facility once one becomes available. See issue 12445. + // Also see crypto/internal/alias/alias.go:AnyOverlap + return uintptr(unsafe.Pointer(&a[0])) <= uintptr(unsafe.Pointer(&b[len(b)-1]))+(elemSize-1) && + uintptr(unsafe.Pointer(&b[0])) <= uintptr(unsafe.Pointer(&a[len(a)-1]))+(elemSize-1) +} + +// startIdx returns the index in haystack where the needle starts. +// prerequisite: the needle must be aliased entirely inside the haystack. +func startIdx[E any](haystack, needle []E) int { + p := &needle[0] + for i := range haystack { + if p == &haystack[i] { + return i + } + } + // TODO: what if the overlap is by a non-integral number of Es? + panic("needle not found") +} + +// Reverse reverses the elements of the slice in place. +func Reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +} diff --git a/vendor/golang.org/x/exp/slices/sort.go b/vendor/golang.org/x/exp/slices/sort.go new file mode 100644 index 000000000..b67897f76 --- /dev/null +++ b/vendor/golang.org/x/exp/slices/sort.go @@ -0,0 +1,195 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:generate go run $GOROOT/src/sort/gen_sort_variants.go -exp + +package slices + +import ( + "math/bits" + + "golang.org/x/exp/constraints" +) + +// Sort sorts a slice of any ordered type in ascending order. +// When sorting floating-point numbers, NaNs are ordered before other values. +func Sort[S ~[]E, E constraints.Ordered](x S) { + n := len(x) + pdqsortOrdered(x, 0, n, bits.Len(uint(n))) +} + +// SortFunc sorts the slice x in ascending order as determined by the cmp +// function. This sort is not guaranteed to be stable. +// cmp(a, b) should return a negative number when a < b, a positive number when +// a > b and zero when a == b. +// +// SortFunc requires that cmp is a strict weak ordering. +// See https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings. +func SortFunc[S ~[]E, E any](x S, cmp func(a, b E) int) { + n := len(x) + pdqsortCmpFunc(x, 0, n, bits.Len(uint(n)), cmp) +} + +// SortStableFunc sorts the slice x while keeping the original order of equal +// elements, using cmp to compare elements in the same way as [SortFunc]. +func SortStableFunc[S ~[]E, E any](x S, cmp func(a, b E) int) { + stableCmpFunc(x, len(x), cmp) +} + +// IsSorted reports whether x is sorted in ascending order. +func IsSorted[S ~[]E, E constraints.Ordered](x S) bool { + for i := len(x) - 1; i > 0; i-- { + if cmpLess(x[i], x[i-1]) { + return false + } + } + return true +} + +// IsSortedFunc reports whether x is sorted in ascending order, with cmp as the +// comparison function as defined by [SortFunc]. +func IsSortedFunc[S ~[]E, E any](x S, cmp func(a, b E) int) bool { + for i := len(x) - 1; i > 0; i-- { + if cmp(x[i], x[i-1]) < 0 { + return false + } + } + return true +} + +// Min returns the minimal value in x. It panics if x is empty. +// For floating-point numbers, Min propagates NaNs (any NaN value in x +// forces the output to be NaN). +func Min[S ~[]E, E constraints.Ordered](x S) E { + if len(x) < 1 { + panic("slices.Min: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + m = min(m, x[i]) + } + return m +} + +// MinFunc returns the minimal value in x, using cmp to compare elements. +// It panics if x is empty. If there is more than one minimal element +// according to the cmp function, MinFunc returns the first one. +func MinFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E { + if len(x) < 1 { + panic("slices.MinFunc: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + if cmp(x[i], m) < 0 { + m = x[i] + } + } + return m +} + +// Max returns the maximal value in x. It panics if x is empty. +// For floating-point E, Max propagates NaNs (any NaN value in x +// forces the output to be NaN). +func Max[S ~[]E, E constraints.Ordered](x S) E { + if len(x) < 1 { + panic("slices.Max: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + m = max(m, x[i]) + } + return m +} + +// MaxFunc returns the maximal value in x, using cmp to compare elements. +// It panics if x is empty. If there is more than one maximal element +// according to the cmp function, MaxFunc returns the first one. +func MaxFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E { + if len(x) < 1 { + panic("slices.MaxFunc: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + if cmp(x[i], m) > 0 { + m = x[i] + } + } + return m +} + +// BinarySearch searches for target in a sorted slice and returns the position +// where target is found, or the position where target would appear in the +// sort order; it also returns a bool saying whether the target is really found +// in the slice. The slice must be sorted in increasing order. +func BinarySearch[S ~[]E, E constraints.Ordered](x S, target E) (int, bool) { + // Inlining is faster than calling BinarySearchFunc with a lambda. + n := len(x) + // Define x[-1] < target and x[n] >= target. + // Invariant: x[i-1] < target, x[j] >= target. + i, j := 0, n + for i < j { + h := int(uint(i+j) >> 1) // avoid overflow when computing h + // i ≤ h < j + if cmpLess(x[h], target) { + i = h + 1 // preserves x[i-1] < target + } else { + j = h // preserves x[j] >= target + } + } + // i == j, x[i-1] < target, and x[j] (= x[i]) >= target => answer is i. + return i, i < n && (x[i] == target || (isNaN(x[i]) && isNaN(target))) +} + +// BinarySearchFunc works like [BinarySearch], but uses a custom comparison +// function. The slice must be sorted in increasing order, where "increasing" +// is defined by cmp. cmp should return 0 if the slice element matches +// the target, a negative number if the slice element precedes the target, +// or a positive number if the slice element follows the target. +// cmp must implement the same ordering as the slice, such that if +// cmp(a, t) < 0 and cmp(b, t) >= 0, then a must precede b in the slice. +func BinarySearchFunc[S ~[]E, E, T any](x S, target T, cmp func(E, T) int) (int, bool) { + n := len(x) + // Define cmp(x[-1], target) < 0 and cmp(x[n], target) >= 0 . + // Invariant: cmp(x[i - 1], target) < 0, cmp(x[j], target) >= 0. + i, j := 0, n + for i < j { + h := int(uint(i+j) >> 1) // avoid overflow when computing h + // i ≤ h < j + if cmp(x[h], target) < 0 { + i = h + 1 // preserves cmp(x[i - 1], target) < 0 + } else { + j = h // preserves cmp(x[j], target) >= 0 + } + } + // i == j, cmp(x[i-1], target) < 0, and cmp(x[j], target) (= cmp(x[i], target)) >= 0 => answer is i. + return i, i < n && cmp(x[i], target) == 0 +} + +type sortedHint int // hint for pdqsort when choosing the pivot + +const ( + unknownHint sortedHint = iota + increasingHint + decreasingHint +) + +// xorshift paper: https://www.jstatsoft.org/article/view/v008i14/xorshift.pdf +type xorshift uint64 + +func (r *xorshift) Next() uint64 { + *r ^= *r << 13 + *r ^= *r >> 17 + *r ^= *r << 5 + return uint64(*r) +} + +func nextPowerOfTwo(length int) uint { + return 1 << bits.Len(uint(length)) +} + +// isNaN reports whether x is a NaN without requiring the math package. +// This will always return false if T is not floating-point. +func isNaN[T constraints.Ordered](x T) bool { + return x != x +} diff --git a/vendor/golang.org/x/exp/slices/zsortanyfunc.go b/vendor/golang.org/x/exp/slices/zsortanyfunc.go new file mode 100644 index 000000000..06f2c7a24 --- /dev/null +++ b/vendor/golang.org/x/exp/slices/zsortanyfunc.go @@ -0,0 +1,479 @@ +// Code generated by gen_sort_variants.go; DO NOT EDIT. + +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package slices + +// insertionSortCmpFunc sorts data[a:b] using insertion sort. +func insertionSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { + for i := a + 1; i < b; i++ { + for j := i; j > a && (cmp(data[j], data[j-1]) < 0); j-- { + data[j], data[j-1] = data[j-1], data[j] + } + } +} + +// siftDownCmpFunc implements the heap property on data[lo:hi]. +// first is an offset into the array where the root of the heap lies. +func siftDownCmpFunc[E any](data []E, lo, hi, first int, cmp func(a, b E) int) { + root := lo + for { + child := 2*root + 1 + if child >= hi { + break + } + if child+1 < hi && (cmp(data[first+child], data[first+child+1]) < 0) { + child++ + } + if !(cmp(data[first+root], data[first+child]) < 0) { + return + } + data[first+root], data[first+child] = data[first+child], data[first+root] + root = child + } +} + +func heapSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { + first := a + lo := 0 + hi := b - a + + // Build heap with greatest element at top. + for i := (hi - 1) / 2; i >= 0; i-- { + siftDownCmpFunc(data, i, hi, first, cmp) + } + + // Pop elements, largest first, into end of data. + for i := hi - 1; i >= 0; i-- { + data[first], data[first+i] = data[first+i], data[first] + siftDownCmpFunc(data, lo, i, first, cmp) + } +} + +// pdqsortCmpFunc sorts data[a:b]. +// The algorithm based on pattern-defeating quicksort(pdqsort), but without the optimizations from BlockQuicksort. +// pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf +// C++ implementation: https://github.com/orlp/pdqsort +// Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/ +// limit is the number of allowed bad (very unbalanced) pivots before falling back to heapsort. +func pdqsortCmpFunc[E any](data []E, a, b, limit int, cmp func(a, b E) int) { + const maxInsertion = 12 + + var ( + wasBalanced = true // whether the last partitioning was reasonably balanced + wasPartitioned = true // whether the slice was already partitioned + ) + + for { + length := b - a + + if length <= maxInsertion { + insertionSortCmpFunc(data, a, b, cmp) + return + } + + // Fall back to heapsort if too many bad choices were made. + if limit == 0 { + heapSortCmpFunc(data, a, b, cmp) + return + } + + // If the last partitioning was imbalanced, we need to breaking patterns. + if !wasBalanced { + breakPatternsCmpFunc(data, a, b, cmp) + limit-- + } + + pivot, hint := choosePivotCmpFunc(data, a, b, cmp) + if hint == decreasingHint { + reverseRangeCmpFunc(data, a, b, cmp) + // The chosen pivot was pivot-a elements after the start of the array. + // After reversing it is pivot-a elements before the end of the array. + // The idea came from Rust's implementation. + pivot = (b - 1) - (pivot - a) + hint = increasingHint + } + + // The slice is likely already sorted. + if wasBalanced && wasPartitioned && hint == increasingHint { + if partialInsertionSortCmpFunc(data, a, b, cmp) { + return + } + } + + // Probably the slice contains many duplicate elements, partition the slice into + // elements equal to and elements greater than the pivot. + if a > 0 && !(cmp(data[a-1], data[pivot]) < 0) { + mid := partitionEqualCmpFunc(data, a, b, pivot, cmp) + a = mid + continue + } + + mid, alreadyPartitioned := partitionCmpFunc(data, a, b, pivot, cmp) + wasPartitioned = alreadyPartitioned + + leftLen, rightLen := mid-a, b-mid + balanceThreshold := length / 8 + if leftLen < rightLen { + wasBalanced = leftLen >= balanceThreshold + pdqsortCmpFunc(data, a, mid, limit, cmp) + a = mid + 1 + } else { + wasBalanced = rightLen >= balanceThreshold + pdqsortCmpFunc(data, mid+1, b, limit, cmp) + b = mid + } + } +} + +// partitionCmpFunc does one quicksort partition. +// Let p = data[pivot] +// Moves elements in data[a:b] around, so that data[i]

=p for inewpivot. +// On return, data[newpivot] = p +func partitionCmpFunc[E any](data []E, a, b, pivot int, cmp func(a, b E) int) (newpivot int, alreadyPartitioned bool) { + data[a], data[pivot] = data[pivot], data[a] + i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned + + for i <= j && (cmp(data[i], data[a]) < 0) { + i++ + } + for i <= j && !(cmp(data[j], data[a]) < 0) { + j-- + } + if i > j { + data[j], data[a] = data[a], data[j] + return j, true + } + data[i], data[j] = data[j], data[i] + i++ + j-- + + for { + for i <= j && (cmp(data[i], data[a]) < 0) { + i++ + } + for i <= j && !(cmp(data[j], data[a]) < 0) { + j-- + } + if i > j { + break + } + data[i], data[j] = data[j], data[i] + i++ + j-- + } + data[j], data[a] = data[a], data[j] + return j, false +} + +// partitionEqualCmpFunc partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot]. +// It assumed that data[a:b] does not contain elements smaller than the data[pivot]. +func partitionEqualCmpFunc[E any](data []E, a, b, pivot int, cmp func(a, b E) int) (newpivot int) { + data[a], data[pivot] = data[pivot], data[a] + i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned + + for { + for i <= j && !(cmp(data[a], data[i]) < 0) { + i++ + } + for i <= j && (cmp(data[a], data[j]) < 0) { + j-- + } + if i > j { + break + } + data[i], data[j] = data[j], data[i] + i++ + j-- + } + return i +} + +// partialInsertionSortCmpFunc partially sorts a slice, returns true if the slice is sorted at the end. +func partialInsertionSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) bool { + const ( + maxSteps = 5 // maximum number of adjacent out-of-order pairs that will get shifted + shortestShifting = 50 // don't shift any elements on short arrays + ) + i := a + 1 + for j := 0; j < maxSteps; j++ { + for i < b && !(cmp(data[i], data[i-1]) < 0) { + i++ + } + + if i == b { + return true + } + + if b-a < shortestShifting { + return false + } + + data[i], data[i-1] = data[i-1], data[i] + + // Shift the smaller one to the left. + if i-a >= 2 { + for j := i - 1; j >= 1; j-- { + if !(cmp(data[j], data[j-1]) < 0) { + break + } + data[j], data[j-1] = data[j-1], data[j] + } + } + // Shift the greater one to the right. + if b-i >= 2 { + for j := i + 1; j < b; j++ { + if !(cmp(data[j], data[j-1]) < 0) { + break + } + data[j], data[j-1] = data[j-1], data[j] + } + } + } + return false +} + +// breakPatternsCmpFunc scatters some elements around in an attempt to break some patterns +// that might cause imbalanced partitions in quicksort. +func breakPatternsCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { + length := b - a + if length >= 8 { + random := xorshift(length) + modulus := nextPowerOfTwo(length) + + for idx := a + (length/4)*2 - 1; idx <= a+(length/4)*2+1; idx++ { + other := int(uint(random.Next()) & (modulus - 1)) + if other >= length { + other -= length + } + data[idx], data[a+other] = data[a+other], data[idx] + } + } +} + +// choosePivotCmpFunc chooses a pivot in data[a:b]. +// +// [0,8): chooses a static pivot. +// [8,shortestNinther): uses the simple median-of-three method. +// [shortestNinther,∞): uses the Tukey ninther method. +func choosePivotCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) (pivot int, hint sortedHint) { + const ( + shortestNinther = 50 + maxSwaps = 4 * 3 + ) + + l := b - a + + var ( + swaps int + i = a + l/4*1 + j = a + l/4*2 + k = a + l/4*3 + ) + + if l >= 8 { + if l >= shortestNinther { + // Tukey ninther method, the idea came from Rust's implementation. + i = medianAdjacentCmpFunc(data, i, &swaps, cmp) + j = medianAdjacentCmpFunc(data, j, &swaps, cmp) + k = medianAdjacentCmpFunc(data, k, &swaps, cmp) + } + // Find the median among i, j, k and stores it into j. + j = medianCmpFunc(data, i, j, k, &swaps, cmp) + } + + switch swaps { + case 0: + return j, increasingHint + case maxSwaps: + return j, decreasingHint + default: + return j, unknownHint + } +} + +// order2CmpFunc returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. +func order2CmpFunc[E any](data []E, a, b int, swaps *int, cmp func(a, b E) int) (int, int) { + if cmp(data[b], data[a]) < 0 { + *swaps++ + return b, a + } + return a, b +} + +// medianCmpFunc returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c. +func medianCmpFunc[E any](data []E, a, b, c int, swaps *int, cmp func(a, b E) int) int { + a, b = order2CmpFunc(data, a, b, swaps, cmp) + b, c = order2CmpFunc(data, b, c, swaps, cmp) + a, b = order2CmpFunc(data, a, b, swaps, cmp) + return b +} + +// medianAdjacentCmpFunc finds the median of data[a - 1], data[a], data[a + 1] and stores the index into a. +func medianAdjacentCmpFunc[E any](data []E, a int, swaps *int, cmp func(a, b E) int) int { + return medianCmpFunc(data, a-1, a, a+1, swaps, cmp) +} + +func reverseRangeCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { + i := a + j := b - 1 + for i < j { + data[i], data[j] = data[j], data[i] + i++ + j-- + } +} + +func swapRangeCmpFunc[E any](data []E, a, b, n int, cmp func(a, b E) int) { + for i := 0; i < n; i++ { + data[a+i], data[b+i] = data[b+i], data[a+i] + } +} + +func stableCmpFunc[E any](data []E, n int, cmp func(a, b E) int) { + blockSize := 20 // must be > 0 + a, b := 0, blockSize + for b <= n { + insertionSortCmpFunc(data, a, b, cmp) + a = b + b += blockSize + } + insertionSortCmpFunc(data, a, n, cmp) + + for blockSize < n { + a, b = 0, 2*blockSize + for b <= n { + symMergeCmpFunc(data, a, a+blockSize, b, cmp) + a = b + b += 2 * blockSize + } + if m := a + blockSize; m < n { + symMergeCmpFunc(data, a, m, n, cmp) + } + blockSize *= 2 + } +} + +// symMergeCmpFunc merges the two sorted subsequences data[a:m] and data[m:b] using +// the SymMerge algorithm from Pok-Son Kim and Arne Kutzner, "Stable Minimum +// Storage Merging by Symmetric Comparisons", in Susanne Albers and Tomasz +// Radzik, editors, Algorithms - ESA 2004, volume 3221 of Lecture Notes in +// Computer Science, pages 714-723. Springer, 2004. +// +// Let M = m-a and N = b-n. Wolog M < N. +// The recursion depth is bound by ceil(log(N+M)). +// The algorithm needs O(M*log(N/M + 1)) calls to data.Less. +// The algorithm needs O((M+N)*log(M)) calls to data.Swap. +// +// The paper gives O((M+N)*log(M)) as the number of assignments assuming a +// rotation algorithm which uses O(M+N+gcd(M+N)) assignments. The argumentation +// in the paper carries through for Swap operations, especially as the block +// swapping rotate uses only O(M+N) Swaps. +// +// symMerge assumes non-degenerate arguments: a < m && m < b. +// Having the caller check this condition eliminates many leaf recursion calls, +// which improves performance. +func symMergeCmpFunc[E any](data []E, a, m, b int, cmp func(a, b E) int) { + // Avoid unnecessary recursions of symMerge + // by direct insertion of data[a] into data[m:b] + // if data[a:m] only contains one element. + if m-a == 1 { + // Use binary search to find the lowest index i + // such that data[i] >= data[a] for m <= i < b. + // Exit the search loop with i == b in case no such index exists. + i := m + j := b + for i < j { + h := int(uint(i+j) >> 1) + if cmp(data[h], data[a]) < 0 { + i = h + 1 + } else { + j = h + } + } + // Swap values until data[a] reaches the position before i. + for k := a; k < i-1; k++ { + data[k], data[k+1] = data[k+1], data[k] + } + return + } + + // Avoid unnecessary recursions of symMerge + // by direct insertion of data[m] into data[a:m] + // if data[m:b] only contains one element. + if b-m == 1 { + // Use binary search to find the lowest index i + // such that data[i] > data[m] for a <= i < m. + // Exit the search loop with i == m in case no such index exists. + i := a + j := m + for i < j { + h := int(uint(i+j) >> 1) + if !(cmp(data[m], data[h]) < 0) { + i = h + 1 + } else { + j = h + } + } + // Swap values until data[m] reaches the position i. + for k := m; k > i; k-- { + data[k], data[k-1] = data[k-1], data[k] + } + return + } + + mid := int(uint(a+b) >> 1) + n := mid + m + var start, r int + if m > mid { + start = n - b + r = mid + } else { + start = a + r = m + } + p := n - 1 + + for start < r { + c := int(uint(start+r) >> 1) + if !(cmp(data[p-c], data[c]) < 0) { + start = c + 1 + } else { + r = c + } + } + + end := n - start + if start < m && m < end { + rotateCmpFunc(data, start, m, end, cmp) + } + if a < start && start < mid { + symMergeCmpFunc(data, a, start, mid, cmp) + } + if mid < end && end < b { + symMergeCmpFunc(data, mid, end, b, cmp) + } +} + +// rotateCmpFunc rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data: +// Data of the form 'x u v y' is changed to 'x v u y'. +// rotate performs at most b-a many calls to data.Swap, +// and it assumes non-degenerate arguments: a < m && m < b. +func rotateCmpFunc[E any](data []E, a, m, b int, cmp func(a, b E) int) { + i := m - a + j := b - m + + for i != j { + if i > j { + swapRangeCmpFunc(data, m-i, m, j, cmp) + i -= j + } else { + swapRangeCmpFunc(data, m-i, m+j-i, i, cmp) + j -= i + } + } + // i == j + swapRangeCmpFunc(data, m-i, m, i, cmp) +} diff --git a/vendor/golang.org/x/exp/slices/zsortordered.go b/vendor/golang.org/x/exp/slices/zsortordered.go new file mode 100644 index 000000000..99b47c398 --- /dev/null +++ b/vendor/golang.org/x/exp/slices/zsortordered.go @@ -0,0 +1,481 @@ +// Code generated by gen_sort_variants.go; DO NOT EDIT. + +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package slices + +import "golang.org/x/exp/constraints" + +// insertionSortOrdered sorts data[a:b] using insertion sort. +func insertionSortOrdered[E constraints.Ordered](data []E, a, b int) { + for i := a + 1; i < b; i++ { + for j := i; j > a && cmpLess(data[j], data[j-1]); j-- { + data[j], data[j-1] = data[j-1], data[j] + } + } +} + +// siftDownOrdered implements the heap property on data[lo:hi]. +// first is an offset into the array where the root of the heap lies. +func siftDownOrdered[E constraints.Ordered](data []E, lo, hi, first int) { + root := lo + for { + child := 2*root + 1 + if child >= hi { + break + } + if child+1 < hi && cmpLess(data[first+child], data[first+child+1]) { + child++ + } + if !cmpLess(data[first+root], data[first+child]) { + return + } + data[first+root], data[first+child] = data[first+child], data[first+root] + root = child + } +} + +func heapSortOrdered[E constraints.Ordered](data []E, a, b int) { + first := a + lo := 0 + hi := b - a + + // Build heap with greatest element at top. + for i := (hi - 1) / 2; i >= 0; i-- { + siftDownOrdered(data, i, hi, first) + } + + // Pop elements, largest first, into end of data. + for i := hi - 1; i >= 0; i-- { + data[first], data[first+i] = data[first+i], data[first] + siftDownOrdered(data, lo, i, first) + } +} + +// pdqsortOrdered sorts data[a:b]. +// The algorithm based on pattern-defeating quicksort(pdqsort), but without the optimizations from BlockQuicksort. +// pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf +// C++ implementation: https://github.com/orlp/pdqsort +// Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/ +// limit is the number of allowed bad (very unbalanced) pivots before falling back to heapsort. +func pdqsortOrdered[E constraints.Ordered](data []E, a, b, limit int) { + const maxInsertion = 12 + + var ( + wasBalanced = true // whether the last partitioning was reasonably balanced + wasPartitioned = true // whether the slice was already partitioned + ) + + for { + length := b - a + + if length <= maxInsertion { + insertionSortOrdered(data, a, b) + return + } + + // Fall back to heapsort if too many bad choices were made. + if limit == 0 { + heapSortOrdered(data, a, b) + return + } + + // If the last partitioning was imbalanced, we need to breaking patterns. + if !wasBalanced { + breakPatternsOrdered(data, a, b) + limit-- + } + + pivot, hint := choosePivotOrdered(data, a, b) + if hint == decreasingHint { + reverseRangeOrdered(data, a, b) + // The chosen pivot was pivot-a elements after the start of the array. + // After reversing it is pivot-a elements before the end of the array. + // The idea came from Rust's implementation. + pivot = (b - 1) - (pivot - a) + hint = increasingHint + } + + // The slice is likely already sorted. + if wasBalanced && wasPartitioned && hint == increasingHint { + if partialInsertionSortOrdered(data, a, b) { + return + } + } + + // Probably the slice contains many duplicate elements, partition the slice into + // elements equal to and elements greater than the pivot. + if a > 0 && !cmpLess(data[a-1], data[pivot]) { + mid := partitionEqualOrdered(data, a, b, pivot) + a = mid + continue + } + + mid, alreadyPartitioned := partitionOrdered(data, a, b, pivot) + wasPartitioned = alreadyPartitioned + + leftLen, rightLen := mid-a, b-mid + balanceThreshold := length / 8 + if leftLen < rightLen { + wasBalanced = leftLen >= balanceThreshold + pdqsortOrdered(data, a, mid, limit) + a = mid + 1 + } else { + wasBalanced = rightLen >= balanceThreshold + pdqsortOrdered(data, mid+1, b, limit) + b = mid + } + } +} + +// partitionOrdered does one quicksort partition. +// Let p = data[pivot] +// Moves elements in data[a:b] around, so that data[i]

=p for inewpivot. +// On return, data[newpivot] = p +func partitionOrdered[E constraints.Ordered](data []E, a, b, pivot int) (newpivot int, alreadyPartitioned bool) { + data[a], data[pivot] = data[pivot], data[a] + i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned + + for i <= j && cmpLess(data[i], data[a]) { + i++ + } + for i <= j && !cmpLess(data[j], data[a]) { + j-- + } + if i > j { + data[j], data[a] = data[a], data[j] + return j, true + } + data[i], data[j] = data[j], data[i] + i++ + j-- + + for { + for i <= j && cmpLess(data[i], data[a]) { + i++ + } + for i <= j && !cmpLess(data[j], data[a]) { + j-- + } + if i > j { + break + } + data[i], data[j] = data[j], data[i] + i++ + j-- + } + data[j], data[a] = data[a], data[j] + return j, false +} + +// partitionEqualOrdered partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot]. +// It assumed that data[a:b] does not contain elements smaller than the data[pivot]. +func partitionEqualOrdered[E constraints.Ordered](data []E, a, b, pivot int) (newpivot int) { + data[a], data[pivot] = data[pivot], data[a] + i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned + + for { + for i <= j && !cmpLess(data[a], data[i]) { + i++ + } + for i <= j && cmpLess(data[a], data[j]) { + j-- + } + if i > j { + break + } + data[i], data[j] = data[j], data[i] + i++ + j-- + } + return i +} + +// partialInsertionSortOrdered partially sorts a slice, returns true if the slice is sorted at the end. +func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool { + const ( + maxSteps = 5 // maximum number of adjacent out-of-order pairs that will get shifted + shortestShifting = 50 // don't shift any elements on short arrays + ) + i := a + 1 + for j := 0; j < maxSteps; j++ { + for i < b && !cmpLess(data[i], data[i-1]) { + i++ + } + + if i == b { + return true + } + + if b-a < shortestShifting { + return false + } + + data[i], data[i-1] = data[i-1], data[i] + + // Shift the smaller one to the left. + if i-a >= 2 { + for j := i - 1; j >= 1; j-- { + if !cmpLess(data[j], data[j-1]) { + break + } + data[j], data[j-1] = data[j-1], data[j] + } + } + // Shift the greater one to the right. + if b-i >= 2 { + for j := i + 1; j < b; j++ { + if !cmpLess(data[j], data[j-1]) { + break + } + data[j], data[j-1] = data[j-1], data[j] + } + } + } + return false +} + +// breakPatternsOrdered scatters some elements around in an attempt to break some patterns +// that might cause imbalanced partitions in quicksort. +func breakPatternsOrdered[E constraints.Ordered](data []E, a, b int) { + length := b - a + if length >= 8 { + random := xorshift(length) + modulus := nextPowerOfTwo(length) + + for idx := a + (length/4)*2 - 1; idx <= a+(length/4)*2+1; idx++ { + other := int(uint(random.Next()) & (modulus - 1)) + if other >= length { + other -= length + } + data[idx], data[a+other] = data[a+other], data[idx] + } + } +} + +// choosePivotOrdered chooses a pivot in data[a:b]. +// +// [0,8): chooses a static pivot. +// [8,shortestNinther): uses the simple median-of-three method. +// [shortestNinther,∞): uses the Tukey ninther method. +func choosePivotOrdered[E constraints.Ordered](data []E, a, b int) (pivot int, hint sortedHint) { + const ( + shortestNinther = 50 + maxSwaps = 4 * 3 + ) + + l := b - a + + var ( + swaps int + i = a + l/4*1 + j = a + l/4*2 + k = a + l/4*3 + ) + + if l >= 8 { + if l >= shortestNinther { + // Tukey ninther method, the idea came from Rust's implementation. + i = medianAdjacentOrdered(data, i, &swaps) + j = medianAdjacentOrdered(data, j, &swaps) + k = medianAdjacentOrdered(data, k, &swaps) + } + // Find the median among i, j, k and stores it into j. + j = medianOrdered(data, i, j, k, &swaps) + } + + switch swaps { + case 0: + return j, increasingHint + case maxSwaps: + return j, decreasingHint + default: + return j, unknownHint + } +} + +// order2Ordered returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. +func order2Ordered[E constraints.Ordered](data []E, a, b int, swaps *int) (int, int) { + if cmpLess(data[b], data[a]) { + *swaps++ + return b, a + } + return a, b +} + +// medianOrdered returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c. +func medianOrdered[E constraints.Ordered](data []E, a, b, c int, swaps *int) int { + a, b = order2Ordered(data, a, b, swaps) + b, c = order2Ordered(data, b, c, swaps) + a, b = order2Ordered(data, a, b, swaps) + return b +} + +// medianAdjacentOrdered finds the median of data[a - 1], data[a], data[a + 1] and stores the index into a. +func medianAdjacentOrdered[E constraints.Ordered](data []E, a int, swaps *int) int { + return medianOrdered(data, a-1, a, a+1, swaps) +} + +func reverseRangeOrdered[E constraints.Ordered](data []E, a, b int) { + i := a + j := b - 1 + for i < j { + data[i], data[j] = data[j], data[i] + i++ + j-- + } +} + +func swapRangeOrdered[E constraints.Ordered](data []E, a, b, n int) { + for i := 0; i < n; i++ { + data[a+i], data[b+i] = data[b+i], data[a+i] + } +} + +func stableOrdered[E constraints.Ordered](data []E, n int) { + blockSize := 20 // must be > 0 + a, b := 0, blockSize + for b <= n { + insertionSortOrdered(data, a, b) + a = b + b += blockSize + } + insertionSortOrdered(data, a, n) + + for blockSize < n { + a, b = 0, 2*blockSize + for b <= n { + symMergeOrdered(data, a, a+blockSize, b) + a = b + b += 2 * blockSize + } + if m := a + blockSize; m < n { + symMergeOrdered(data, a, m, n) + } + blockSize *= 2 + } +} + +// symMergeOrdered merges the two sorted subsequences data[a:m] and data[m:b] using +// the SymMerge algorithm from Pok-Son Kim and Arne Kutzner, "Stable Minimum +// Storage Merging by Symmetric Comparisons", in Susanne Albers and Tomasz +// Radzik, editors, Algorithms - ESA 2004, volume 3221 of Lecture Notes in +// Computer Science, pages 714-723. Springer, 2004. +// +// Let M = m-a and N = b-n. Wolog M < N. +// The recursion depth is bound by ceil(log(N+M)). +// The algorithm needs O(M*log(N/M + 1)) calls to data.Less. +// The algorithm needs O((M+N)*log(M)) calls to data.Swap. +// +// The paper gives O((M+N)*log(M)) as the number of assignments assuming a +// rotation algorithm which uses O(M+N+gcd(M+N)) assignments. The argumentation +// in the paper carries through for Swap operations, especially as the block +// swapping rotate uses only O(M+N) Swaps. +// +// symMerge assumes non-degenerate arguments: a < m && m < b. +// Having the caller check this condition eliminates many leaf recursion calls, +// which improves performance. +func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { + // Avoid unnecessary recursions of symMerge + // by direct insertion of data[a] into data[m:b] + // if data[a:m] only contains one element. + if m-a == 1 { + // Use binary search to find the lowest index i + // such that data[i] >= data[a] for m <= i < b. + // Exit the search loop with i == b in case no such index exists. + i := m + j := b + for i < j { + h := int(uint(i+j) >> 1) + if cmpLess(data[h], data[a]) { + i = h + 1 + } else { + j = h + } + } + // Swap values until data[a] reaches the position before i. + for k := a; k < i-1; k++ { + data[k], data[k+1] = data[k+1], data[k] + } + return + } + + // Avoid unnecessary recursions of symMerge + // by direct insertion of data[m] into data[a:m] + // if data[m:b] only contains one element. + if b-m == 1 { + // Use binary search to find the lowest index i + // such that data[i] > data[m] for a <= i < m. + // Exit the search loop with i == m in case no such index exists. + i := a + j := m + for i < j { + h := int(uint(i+j) >> 1) + if !cmpLess(data[m], data[h]) { + i = h + 1 + } else { + j = h + } + } + // Swap values until data[m] reaches the position i. + for k := m; k > i; k-- { + data[k], data[k-1] = data[k-1], data[k] + } + return + } + + mid := int(uint(a+b) >> 1) + n := mid + m + var start, r int + if m > mid { + start = n - b + r = mid + } else { + start = a + r = m + } + p := n - 1 + + for start < r { + c := int(uint(start+r) >> 1) + if !cmpLess(data[p-c], data[c]) { + start = c + 1 + } else { + r = c + } + } + + end := n - start + if start < m && m < end { + rotateOrdered(data, start, m, end) + } + if a < start && start < mid { + symMergeOrdered(data, a, start, mid) + } + if mid < end && end < b { + symMergeOrdered(data, mid, end, b) + } +} + +// rotateOrdered rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data: +// Data of the form 'x u v y' is changed to 'x v u y'. +// rotate performs at most b-a many calls to data.Swap, +// and it assumes non-degenerate arguments: a < m && m < b. +func rotateOrdered[E constraints.Ordered](data []E, a, m, b int) { + i := m - a + j := b - m + + for i != j { + if i > j { + swapRangeOrdered(data, m-i, m, j) + i -= j + } else { + swapRangeOrdered(data, m-i, m+j-i, i) + j -= i + } + } + // i == j + swapRangeOrdered(data, m-i, m, i) +} diff --git a/vendor/golang.org/x/tools/go/ast/inspector/inspector.go b/vendor/golang.org/x/tools/go/ast/inspector/inspector.go index 3fbfebf36..1fc1de0bd 100644 --- a/vendor/golang.org/x/tools/go/ast/inspector/inspector.go +++ b/vendor/golang.org/x/tools/go/ast/inspector/inspector.go @@ -64,8 +64,9 @@ type event struct { // depth-first order. It calls f(n) for each node n before it visits // n's children. // +// The complete traversal sequence is determined by ast.Inspect. // The types argument, if non-empty, enables type-based filtering of -// events. The function f if is called only for nodes whose type +// events. The function f is called only for nodes whose type // matches an element of the types slice. func (in *Inspector) Preorder(types []ast.Node, f func(ast.Node)) { // Because it avoids postorder calls to f, and the pruning @@ -97,6 +98,7 @@ func (in *Inspector) Preorder(types []ast.Node, f func(ast.Node)) { // of the non-nil children of the node, followed by a call of // f(n, false). // +// The complete traversal sequence is determined by ast.Inspect. // The types argument, if non-empty, enables type-based filtering of // events. The function f if is called only for nodes whose type // matches an element of the types slice. diff --git a/vendor/golang.org/x/tools/internal/typeparams/common.go b/vendor/golang.org/x/tools/internal/typeparams/common.go index 25a1426d3..d0d0649fe 100644 --- a/vendor/golang.org/x/tools/internal/typeparams/common.go +++ b/vendor/golang.org/x/tools/internal/typeparams/common.go @@ -23,6 +23,7 @@ package typeparams import ( + "fmt" "go/ast" "go/token" "go/types" @@ -87,7 +88,6 @@ func IsTypeParam(t types.Type) bool { func OriginMethod(fn *types.Func) *types.Func { recv := fn.Type().(*types.Signature).Recv() if recv == nil { - return fn } base := recv.Type() @@ -106,6 +106,31 @@ func OriginMethod(fn *types.Func) *types.Func { } orig := NamedTypeOrigin(named) gfn, _, _ := types.LookupFieldOrMethod(orig, true, fn.Pkg(), fn.Name()) + + // This is a fix for a gopls crash (#60628) due to a go/types bug (#60634). In: + // package p + // type T *int + // func (*T) f() {} + // LookupFieldOrMethod(T, true, p, f)=nil, but NewMethodSet(*T)={(*T).f}. + // Here we make them consistent by force. + // (The go/types bug is general, but this workaround is reached only + // for generic T thanks to the early return above.) + if gfn == nil { + mset := types.NewMethodSet(types.NewPointer(orig)) + for i := 0; i < mset.Len(); i++ { + m := mset.At(i) + if m.Obj().Id() == fn.Id() { + gfn = m.Obj() + break + } + } + } + + // In golang/go#61196, we observe another crash, this time inexplicable. + if gfn == nil { + panic(fmt.Sprintf("missing origin method for %s.%s; named == origin: %t, named.NumMethods(): %d, origin.NumMethods(): %d", named, fn, named == orig, named.NumMethods(), orig.NumMethods())) + } + return gfn.(*types.Func) } diff --git a/vendor/golang.org/x/tools/internal/typeparams/typeparams_go117.go b/vendor/golang.org/x/tools/internal/typeparams/typeparams_go117.go index b4788978f..7ed86e171 100644 --- a/vendor/golang.org/x/tools/internal/typeparams/typeparams_go117.go +++ b/vendor/golang.org/x/tools/internal/typeparams/typeparams_go117.go @@ -129,7 +129,7 @@ func NamedTypeArgs(*types.Named) *TypeList { } // NamedTypeOrigin is the identity method at this Go version. -func NamedTypeOrigin(named *types.Named) types.Type { +func NamedTypeOrigin(named *types.Named) *types.Named { return named } diff --git a/vendor/golang.org/x/tools/internal/typeparams/typeparams_go118.go b/vendor/golang.org/x/tools/internal/typeparams/typeparams_go118.go index 114a36b86..cf301af1d 100644 --- a/vendor/golang.org/x/tools/internal/typeparams/typeparams_go118.go +++ b/vendor/golang.org/x/tools/internal/typeparams/typeparams_go118.go @@ -103,7 +103,7 @@ func NamedTypeArgs(named *types.Named) *TypeList { } // NamedTypeOrigin returns named.Orig(). -func NamedTypeOrigin(named *types.Named) types.Type { +func NamedTypeOrigin(named *types.Named) *types.Named { return named.Origin() } diff --git a/vendor/modules.txt b/vendor/modules.txt index fe08aa0a2..70edf497d 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -338,6 +338,10 @@ golang.org/x/crypto/internal/alias golang.org/x/crypto/internal/poly1305 golang.org/x/crypto/ssh golang.org/x/crypto/ssh/internal/bcrypt_pbkdf +# golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 +## explicit; go 1.20 +golang.org/x/exp/constraints +golang.org/x/exp/slices # golang.org/x/net v0.14.0 ## explicit; go 1.17 golang.org/x/net/context @@ -393,7 +397,7 @@ golang.org/x/text/width # golang.org/x/time v0.0.0-20220609170525-579cf78fd858 ## explicit golang.org/x/time/rate -# golang.org/x/tools v0.6.0 +# golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 ## explicit; go 1.18 golang.org/x/tools/go/ast/inspector golang.org/x/tools/internal/typeparams