diff --git a/pkg/driver/node_server.go b/pkg/driver/node_server.go index 40e2e21..577b21b 100644 --- a/pkg/driver/node_server.go +++ b/pkg/driver/node_server.go @@ -203,7 +203,7 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish func (d *Driver) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { log.Info().Msg("Request: NodeGetInfo") - nodeInstanceID, region, err := currentNodeDetails() + nodeInstanceID, region, err := d.currentNodeDetails() if err != nil { return nil, status.Error(codes.Internal, err.Error()) } @@ -343,24 +343,41 @@ type civostatsdConfig struct { InstanceID string `toml:"instance_id"` } -func currentNodeDetails() (string, string, error) { +func (d *Driver) currentNodeDetails() (string, string, error) { configFile := "/etc/civostatsd" _, err := os.Stat(configFile) if err != nil { log.Debug().Msg("Node details file /etc/civostatsd doesn't existing, using ENVironment variables") - return currentNodeDetailsFromEnv() + return d.currentNodeDetailsFromEnv() } var config civostatsdConfig if _, err := toml.DecodeFile(configFile, &config); err != nil { log.Debug().Msg("Node details file /etc/civostatsd isn't valid TOML, using ENVironment variables") - return currentNodeDetailsFromEnv() + return d.currentNodeDetailsFromEnv() } return config.InstanceID, config.Region, nil } -func currentNodeDetailsFromEnv() (string, string, error) { +// Get the node details from the environment variables +// NODE_ID is the ID of the node that can be used to access details from the CIVO API +// REGION is the region that the node is in +// If NODE_ID is not set, then the KUBE_NODE_NAME is used to fetch the node using it's name +func (d *Driver) currentNodeDetailsFromEnv() (string, string, error) { + if os.Getenv("NODE_ID") == "" { + nodeName := os.Getenv("KUBE_NODE_NAME") + if nodeName == "" { + return "", "", fmt.Errorf("NODE_ID is not set and KUBE_NODE_NAME is not set") + } + + instance, err := d.CivoClient.FindKubernetesClusterInstance(d.ClusterID, nodeName) + if err != nil { + return "", "", err + } + // Return the instance ID and the region + return instance.ID, instance.Region, nil + } return os.Getenv("NODE_ID"), os.Getenv("REGION"), nil }