[go: nahoru, domu]

Skip to content

Commit

Permalink
Refactor utils package to not dump everything unrelated into one file (
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaitanyaKulkarni28 committed Feb 7, 2024
1 parent 8aee0e4 commit bc138e1
Show file tree
Hide file tree
Showing 10 changed files with 253 additions and 225 deletions.
40 changes: 17 additions & 23 deletions google_guest_agent/addresses.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,19 @@ import (
"net"
"reflect"
"runtime"
"slices"
"strings"

"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
network "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/network/manager"
"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/run"
"github.com/GoogleCloudPlatform/guest-agent/utils"
"github.com/GoogleCloudPlatform/guest-logging-go/logger"
)

var (
addressKey = regKeyBase + `\ForwardedIps`
oldWSFCAddresses string
oldWSFCEnable bool
interfacesEnabled bool
interfaces []net.Interface
addressKey = regKeyBase + `\ForwardedIps`
oldWSFCAddresses string
oldWSFCEnable bool
)

// addressMgr groups the address-management methods of the agent (Set and
// applyWSFCFilter in this file). It is an empty struct with no fields of its
// own.
type addressMgr struct{}
Expand Down Expand Up @@ -76,7 +74,9 @@ func getForwardsFromRegistry(mac string) ([]string, error) {
oldName := strings.Replace(mac, ":", "", -1)
regFwdIPs, err = readRegMultiString(addressKey, oldName)
if err == nil {
deleteRegKey(addressKey, oldName)
if err = deleteRegKey(addressKey, oldName); err != nil {
logger.Warningf("Failed to delete key: %q, name: %q from registry", addressKey, oldName)
}
}
} else if err != nil {
return nil, err
Expand All @@ -86,13 +86,13 @@ func getForwardsFromRegistry(mac string) ([]string, error) {

func compareRoutes(configuredRoutes, desiredRoutes []string) (toAdd, toRm []string) {
for _, desiredRoute := range desiredRoutes {
if !utils.ContainsString(desiredRoute, configuredRoutes) {
if !slices.Contains(configuredRoutes, desiredRoute) {
toAdd = append(toAdd, desiredRoute)
}
}

for _, configuredRoute := range configuredRoutes {
if !utils.ContainsString(configuredRoute, desiredRoutes) {
if !slices.Contains(desiredRoutes, configuredRoute) {
toRm = append(toRm, configuredRoute)
}
}
Expand Down Expand Up @@ -205,15 +205,15 @@ func (a *addressMgr) applyWSFCFilter(config *cfg.Sections) {
for idx := range interfaces {
var filteredForwardedIps []string
for _, ip := range interfaces[idx].ForwardedIps {
if !utils.ContainsString(ip, wsfcAddrs) {
if !slices.Contains(wsfcAddrs, ip) {
filteredForwardedIps = append(filteredForwardedIps, ip)
}
}
interfaces[idx].ForwardedIps = filteredForwardedIps

var filteredTargetInstanceIps []string
for _, ip := range interfaces[idx].TargetInstanceIps {
if !utils.ContainsString(ip, wsfcAddrs) {
if !slices.Contains(wsfcAddrs, ip) {
filteredTargetInstanceIps = append(filteredTargetInstanceIps, ip)
}
}
Expand Down Expand Up @@ -274,14 +274,8 @@ func (a *addressMgr) Set(ctx context.Context) error {
a.applyWSFCFilter(config)
}

var err error
interfaces, err = net.Interfaces()
if err != nil {
return fmt.Errorf("error populating interfaces: %v", err)
}

// Setup network interfaces.
err = network.SetupInterfaces(ctx, config, newMetadata.Instance.NetworkInterfaces)
err := network.SetupInterfaces(ctx, config, newMetadata.Instance.NetworkInterfaces)
if err != nil {
return fmt.Errorf("failed to setup network interfaces: %v", err)
}
Expand All @@ -295,7 +289,7 @@ func (a *addressMgr) Set(ctx context.Context) error {
for _, ni := range newMetadata.Instance.NetworkInterfaces {
iface, err := network.GetInterfaceByMAC(ni.Mac)
if err != nil {
if !utils.ContainsString(ni.Mac, badMAC) {
if !slices.Contains(badMAC, ni.Mac) {
logger.Errorf("Error getting interface: %s", err)
badMAC = append(badMAC, ni.Mac)
}
Expand Down Expand Up @@ -328,7 +322,7 @@ func (a *addressMgr) Set(ctx context.Context) error {
}
for _, ip := range configuredIPs {
// Only add to `forwardedIPs` if it is recorded in the registry.
if utils.ContainsString(ip, regFwdIPs) {
if slices.Contains(regFwdIPs, ip) {
forwardedIPs = append(forwardedIPs, ip)
}
}
Expand Down Expand Up @@ -371,14 +365,14 @@ func (a *addressMgr) Set(ctx context.Context) error {
var registryEntries []string
for _, ip := range wantIPs {
// If the IP is not in toAdd, add to registry list and continue.
if !utils.ContainsString(ip, toAdd) {
if !slices.Contains(toAdd, ip) {
registryEntries = append(registryEntries, ip)
continue
}
var err error
if runtime.GOOS == "windows" {
// Don't addAddress if this is already configured.
if !utils.ContainsString(ip, configuredIPs) {
if !slices.Contains(configuredIPs, ip) {
err = addAddress(net.ParseIP(ip), net.IPv4Mask(255, 255, 255, 255), uint32(iface.Index))
}
} else {
Expand All @@ -394,7 +388,7 @@ func (a *addressMgr) Set(ctx context.Context) error {
for _, ip := range toRm {
var err error
if runtime.GOOS == "windows" {
if !utils.ContainsString(ip, configuredIPs) {
if !slices.Contains(configuredIPs, ip) {
continue
}
err = removeAddress(net.ParseIP(ip), uint32(iface.Index))
Expand Down
3 changes: 2 additions & 1 deletion google_guest_agent/diagnostics.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"reflect"
"runtime"
"slices"
"sync/atomic"

"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
Expand Down Expand Up @@ -94,7 +95,7 @@ func (d *diagnosticsMgr) Set(ctx context.Context) error {
}

strEntry := newMetadata.Instance.Attributes.Diagnostics
if utils.ContainsString(strEntry, diagnosticsEntries) {
if slices.Contains(diagnosticsEntries, strEntry) {
return nil
}
diagnosticsEntries = append(diagnosticsEntries, strEntry)
Expand Down
3 changes: 2 additions & 1 deletion google_guest_agent/non_windows_accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"os/exec"
"path"
"runtime"
"slices"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -214,7 +215,7 @@ func getUserKeys(mdkeys []string) map[string][]string {
}

if err != nil {
if !utils.ContainsString(trimmedKey, badSSHKeys) {
if !slices.Contains(badSSHKeys, trimmedKey) {
logger.Errorf("%s: %s", err.Error(), trimmedKey)
badSSHKeys = append(badSSHKeys, trimmedKey)
}
Expand Down
3 changes: 2 additions & 1 deletion google_guest_agent/windows_accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"math/big"
"reflect"
"runtime"
"slices"
"strconv"
"strings"

Expand Down Expand Up @@ -422,7 +423,7 @@ func compareAccounts(newKeys metadata.WindowsKeys, oldStrKeys []string) metadata
for _, s := range oldStrKeys {
var key metadata.WindowsKey
if err := json.Unmarshal([]byte(s), &key); err != nil {
if !utils.ContainsString(s, badReg) {
if !slices.Contains(badReg, s) {
logger.Errorf("Bad windows key from registry: %s", err)
badReg = append(badReg, s)
}
Expand Down
54 changes: 27 additions & 27 deletions google_guest_agent/windows_accounts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,8 @@ func TestAccountsDisabled(t *testing.T) {
}
}

// rename this with leading disabled because this is a resource
// intensive test. this test takes approx. 141 seconds to complete, next
// longest test is 0.43 seconds.
func disabledTestNewPwd(t *testing.T) {
// Test takes ~43 sec to complete and is resource intensive.
func TestNewPwd(t *testing.T) {
minPasswordLength := 15
maxPasswordLength := 255
var tests = []struct {
Expand All @@ -133,31 +131,33 @@ func disabledTestNewPwd(t *testing.T) {
}

for _, tt := range tests {
for i := 0; i < 100000; i++ {
pwd, err := newPwd(tt.passwordLength)
if err != nil {
t.Fatal(err)
}
if len(pwd) != tt.wantPasswordLength {
t.Errorf("Password is not %d characters: len(%s)=%d", tt.wantPasswordLength, pwd, len(pwd))
}
var l, u, n, s int
for _, r := range pwd {
switch {
case unicode.IsLower(r):
l = 1
case unicode.IsUpper(r):
u = 1
case unicode.IsDigit(r):
n = 1
case unicode.IsPunct(r) || unicode.IsSymbol(r):
s = 1
t.Run(tt.name, func(t *testing.T) {
for i := 0; i < 100000; i++ {
pwd, err := newPwd(tt.passwordLength)
if err != nil {
t.Fatal(err)
}
if len(pwd) != tt.wantPasswordLength {
t.Errorf("Password is not %d characters: len(%s)=%d", tt.wantPasswordLength, pwd, len(pwd))
}
var l, u, n, s int
for _, r := range pwd {
switch {
case unicode.IsLower(r):
l = 1
case unicode.IsUpper(r):
u = 1
case unicode.IsDigit(r):
n = 1
case unicode.IsPunct(r) || unicode.IsSymbol(r):
s = 1
}
}
if l+u+n+s < 3 {
t.Errorf("Password does not have at least one character from 3 categories: '%v'", pwd)
}
}
if l+u+n+s < 3 {
t.Errorf("Password does not have at least one character from 3 categories: '%v'", pwd)
}
}
})
}
}

Expand Down
80 changes: 80 additions & 0 deletions utils/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2024 Google LLC

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// https://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// OS file utilities shared by the Google Guest Agent and Google Authorized Keys.

package utils

import (
"fmt"
"io/fs"
"os"
"path/filepath"
)

// SaferWriteFile writes to a temporary file and then replaces the expected output file.
// This prevents other processes from reading partial content while the writer is still writing.
//
// The temporary file is created in the same directory as outputFile, so the
// final rename does not cross filesystems. The steps are:
//   - chmod the temporary file to exactly perm (chmod is not subject to the
//     process umask);
//   - write content and flush it to stable storage;
//   - rename the temporary file over outputFile.
//
// Missing parent directories are created. If any step fails, the temporary
// file is closed and removed, and outputFile is left untouched.
func SaferWriteFile(content []byte, outputFile string, perm fs.FileMode) error {
	dir := filepath.Dir(outputFile)
	name := filepath.Base(outputFile)

	// MkdirAll applies its mode to every directory it creates. A file mode such
	// as 0644 has no execute (search) bits, which would leave new directories
	// untraversable, so grant search wherever read is granted (0644 -> 0755,
	// 0600 -> 0700).
	dirPerm := perm.Perm() | (perm.Perm()&0o444)>>2
	if err := os.MkdirAll(dir, dirPerm); err != nil {
		return fmt.Errorf("unable to create required directories %q: %w", dir, err)
	}

	tmp, err := os.CreateTemp(dir, name+"*")
	if err != nil {
		return fmt.Errorf("unable to create temporary file under %q: %w", dir, err)
	}
	tmpName := tmp.Name()

	// Until the rename succeeds, release the descriptor and delete the temporary
	// file on every exit path so failed writes don't litter dir. Errors are
	// deliberately ignored here: the file may already be closed, and the
	// original failure is the one worth reporting.
	committed := false
	defer func() {
		if !committed {
			_ = tmp.Close()
			_ = os.Remove(tmpName)
		}
	}()

	if err := os.Chmod(tmpName, perm); err != nil {
		return fmt.Errorf("unable to set permissions on temporary file %q: %w", tmpName, err)
	}

	if _, err := tmp.Write(content); err != nil {
		return fmt.Errorf("unable to write to a temporary file %q: %w", tmpName, err)
	}

	// Flush the data before renaming so a crash cannot leave outputFile
	// referring to an empty or partially written file.
	if err := tmp.Sync(); err != nil {
		return fmt.Errorf("unable to sync temporary file %q: %w", tmpName, err)
	}

	if err := tmp.Close(); err != nil {
		return fmt.Errorf("failed to close temporary file: %w", err)
	}

	if err := os.Rename(tmpName, outputFile); err != nil {
		return fmt.Errorf("unable to replace %q with temporary file %q: %w", outputFile, tmpName, err)
	}
	committed = true
	return nil
}

// CopyFile duplicates the bytes of src into dst and then forces dst's
// permission bits to perm.
//
// The whole of src is read into memory first. The write goes through
// WriteFile, which creates any missing parent directories of dst. The final
// explicit chmod makes the resulting mode independent of whether dst already
// existed and of the process umask.
func CopyFile(src, dst string, perm fs.FileMode) error {
	data, readErr := os.ReadFile(src)
	if readErr != nil {
		return fmt.Errorf("failed to read %q: %w", src, readErr)
	}

	writeErr := WriteFile(data, dst, perm)
	if writeErr != nil {
		return fmt.Errorf("failed to write %q: %w", dst, writeErr)
	}

	// WriteFile only applies perm (minus umask) when it creates dst, so
	// chmod unconditionally to get a deterministic result.
	chmodErr := os.Chmod(dst, perm)
	if chmodErr != nil {
		return fmt.Errorf("unable to set permissions on destination file %q: %w", dst, chmodErr)
	}
	return nil
}

// WriteFile creates parent directories if required and writes content to the output file.
//
// perm is handed to os.WriteFile, so it is applied (subject to the process
// umask) only when outputFile is newly created. An existing file is truncated
// and keeps its current permissions.
//
// Newly created parent directories get perm plus a search (execute) bit
// wherever perm grants read (0644 -> 0755, 0600 -> 0700). Directories without
// search permission cannot be traversed.
func WriteFile(content []byte, outputFile string, perm fs.FileMode) error {
	dirPerm := perm.Perm() | (perm.Perm()&0o444)>>2
	if err := os.MkdirAll(filepath.Dir(outputFile), dirPerm); err != nil {
		return fmt.Errorf("unable to create required directories for %q: %w", outputFile, err)
	}
	return os.WriteFile(outputFile, content, perm)
}
Loading

0 comments on commit bc138e1

Please sign in to comment.