diff --git a/cmd/containerd-shim-runhcs-v1/pod.go b/cmd/containerd-shim-runhcs-v1/pod.go index 1d2551ee4d..6684529690 100644 --- a/cmd/containerd-shim-runhcs-v1/pod.go +++ b/cmd/containerd-shim-runhcs-v1/pod.go @@ -128,11 +128,15 @@ func createPod(ctx context.Context, events publisher, req *task.CreateTaskReques layerFolders = s.Windows.LayerFolders } wopts := (opts).(*uvm.OptionsWCOW) - wopts.BootFiles, err = layers.GetWCOWUVMBootFilesFromLayers(ctx, req.Rootfs, layerFolders) - if err != nil { - return nil, err + if !wopts.SecurityPolicyEnabled { + // When security policy is enabled SpecToUVMCreateOpts + // above sets up the BootFiles, otherwise we get boot + // files from the rootfs/layerfolders passed to us. + wopts.BootFiles, err = layers.GetWCOWUVMBootFilesFromLayers(ctx, req.Rootfs, layerFolders) + if err != nil { + return nil, err + } } - parent, err = uvm.CreateWCOW(ctx, wopts) if err != nil { return nil, err diff --git a/cmd/gcs-sidecar/main.go b/cmd/gcs-sidecar/main.go new file mode 100644 index 0000000000..aa949d8c80 --- /dev/null +++ b/cmd/gcs-sidecar/main.go @@ -0,0 +1,241 @@ +//go:build windows +// +build windows + +package main + +import ( + "context" + "flag" + "fmt" + "net" + "os" + + "github.com/Microsoft/go-winio" + "github.com/Microsoft/hcsshim/internal/gcs/prot" + shimlog "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/oc" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" + "github.com/sirupsen/logrus" + "go.opencensus.io/trace" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/debug" + + sidecar "github.com/Microsoft/hcsshim/internal/gcs-sidecar" +) + +var ( + defaultLogFile = "C:\\gcs-sidecar-logs.log" + defaultLogLevel = "trace" +) + +type handler struct { + fromsvc chan error +} + +// Accepts new connection and closes listener. 
+func acceptAndClose(ctx context.Context, l net.Listener) (net.Conn, error) {
+	var conn net.Conn
+	ch := make(chan error)
+	go func() {
+		var err error
+		conn, err = l.Accept()
+		ch <- err
+	}()
+	select {
+	case err := <-ch:
+		l.Close()
+		return conn, err
+	case <-ctx.Done():
+	}
+	l.Close() // closing the listener unblocks the pending Accept
+	err := <-ch
+	if err == nil {
+		return conn, err
+	}
+
+	if ctx.Err() != nil {
+		return nil, ctx.Err()
+	}
+	return nil, err
+}
+
+func (h *handler) Execute(args []string, r <-chan svc.ChangeRequest, status chan<- svc.Status) (bool, uint32) {
+	const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown | svc.Accepted(windows.SERVICE_ACCEPT_PARAMCHANGE)
+
+	status <- svc.Status{State: svc.StartPending, Accepts: 0}
+	// unblock runService()
+	h.fromsvc <- nil
+
+	status <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
+
+loop:
+	for c := range r {
+		switch c.Cmd {
+		case svc.Interrogate:
+			status <- c.CurrentStatus
+		case svc.Stop, svc.Shutdown:
+			logrus.Println("Shutting service...!")
+			break loop
+		case svc.Pause:
+			status <- svc.Status{State: svc.Paused, Accepts: cmdsAccepted}
+		case svc.Continue:
+			status <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
+		default:
+			logrus.Printf("Unexpected service control request #%d", c)
+		}
+	}
+
+	status <- svc.Status{State: svc.StopPending}
+	return false, 1
+}
+
+func runService(name string, isDebug bool) error {
+	h := &handler{
+		fromsvc: make(chan error),
+	}
+
+	var err error
+	go func() {
+		if isDebug {
+			err = debug.Run(name, h)
+			if err != nil {
+				logrus.Errorf("Error running service in debug mode.Err: %v", err)
+			}
+		} else {
+			err = svc.Run(name, h)
+			if err != nil {
+				logrus.Errorf("Error running service in Service Control mode.Err %v", err)
+			}
+		}
+		h.fromsvc <- err
+	}()
+
+	// Wait for the first signal from the service handler.
+	logrus.Tracef("waiting for first signal from service handler\n")
+	return <-h.fromsvc
+}
+
+func main() {
+	logLevel := flag.String("loglevel",
+		defaultLogLevel,
+		"Logging Level: trace, debug, info, warning, error, fatal, panic.")
+	logFile := flag.String("logfile",
+		defaultLogFile,
+		"Logging Target. Default is at C:\\gcs-sidecar-logs.log inside UVM")
+
+	flag.Usage = func() {
+		fmt.Fprintf(os.Stderr, "\nUsage of %s:\n", os.Args[0])
+		flag.PrintDefaults()
+		fmt.Fprintf(os.Stderr, "Examples:\n")
+		fmt.Fprintf(os.Stderr, " %s -loglevel=trace -logfile=C:\\sidecarLogs.log \n", os.Args[0])
+	}
+
+	flag.Parse()
+
+	ctx := context.Background()
+	logFileHandle, err := os.OpenFile(*logFile, os.O_RDWR|os.O_CREATE|os.O_SYNC|os.O_TRUNC, 0666)
+	if err != nil {
+		fmt.Printf("error opening file: %v", err)
+	}
+	defer logFileHandle.Close()
+
+	logrus.AddHook(shimlog.NewHook())
+
+	level, err := logrus.ParseLevel(*logLevel)
+	if err != nil {
+		logrus.Fatal(err)
+	}
+	logrus.SetLevel(level)
+	logrus.SetOutput(logFileHandle)
+	trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()})
+	trace.RegisterExporter(&oc.LogrusExporter{})
+
+	if err := windows.SetStdHandle(windows.STD_ERROR_HANDLE, windows.Handle(logFileHandle.Fd())); err != nil {
+		logrus.WithError(err).Error("error redirecting handle")
+		return
+	}
+	os.Stderr = logFileHandle
+
+	chsrv := make(chan error)
+	go func() {
+		defer close(chsrv)
+		// Keep err local to this goroutine so the send below reports runService's result.
+		err := runService("gcs-sidecar", false)
+		if err != nil {
+			logrus.Errorf("error starting gcs-sidecar service: %v", err)
+		}
+		chsrv <- err
+	}()
+
+	select {
+	case <-ctx.Done():
+		logrus.Error("context deadline exceeded")
+		return
+	case r := <-chsrv:
+		if r != nil {
+			logrus.Error(r)
+			return
+		}
+	}
+
+	// 1. Start external server to connect with inbox GCS
+	listener, err := winio.ListenHvsock(&winio.HvsockAddr{
+		VMID:      prot.HvGUIDLoopback,
+		ServiceID: prot.WindowsGcsHvsockServiceID,
+	})
+	if err != nil {
+		logrus.WithError(err).Errorf("error starting listener for sidecar <-> inbox gcs communication")
+		return
+	}
+
+	var gcsListener net.Listener = listener
+	gcsCon, err := acceptAndClose(ctx, gcsListener)
+	if err != nil {
+		logrus.WithError(err).Errorf("error accepting inbox GCS connection")
+		return
+	}
+
+	// 2. Setup connection with external gcs connection started from hcsshim
+	hvsockAddr := &winio.HvsockAddr{
+		VMID:      prot.HvGUIDParent,
+		ServiceID: prot.WindowsSidecarGcsHvsockServiceID,
+	}
+
+	logrus.WithFields(logrus.Fields{
+		"hvsockAddr": hvsockAddr,
+	}).Tracef("Dialing to hcsshim external bridge at address %v", hvsockAddr)
+	shimCon, err := winio.Dial(ctx, hvsockAddr)
+	if err != nil {
+		logrus.WithError(err).Errorf("error dialing hcsshim external bridge")
+		return
+	}
+
+	// gcs-sidecar can be used for non-confidential Hyper-V WCOW
+	// as well. So we do not always want to check for initialPolicyStance
+	var initialEnforcer securitypolicy.SecurityPolicyEnforcer
+	// TODO (kiashok/Mahati): The initialPolicyStance is set to allow
+	// only for dev. This will eventually be set to allow/deny depending
+	// on whether SNP is supported or not.
+	initialPolicyStance := "allow"
+	switch initialPolicyStance {
+	case "allow":
+		initialEnforcer = &securitypolicy.OpenDoorSecurityPolicyEnforcer{}
+		logrus.Tracef("initial-policy-stance: allow")
+	case "deny":
+		initialEnforcer = &securitypolicy.ClosedDoorSecurityPolicyEnforcer{}
+		logrus.Tracef("initial-policy-stance: deny")
+	default:
+		logrus.Error("unknown initial-policy-stance")
+	}
+
+	// 3. Create the bridge and register its request handlers.
+	brdg := sidecar.NewBridge(shimCon, gcsCon, initialEnforcer)
+	brdg.AssignHandlers()
+
+	// 4. Listen and serve hcsshim requests.
+	err = brdg.ListenAndServeShimRequests()
+	if err != nil {
+		logrus.WithError(err).Errorf("failed to serve request")
+	}
+}
diff --git a/go.mod b/go.mod
index fa2d7d097c..d0645144a5 100644
--- a/go.mod
+++ b/go.mod
@@ -15,6 +15,7 @@ require (
 	github.com/containerd/errdefs v0.3.0
 	github.com/containerd/errdefs/pkg v0.3.0
 	github.com/containerd/go-runc v1.0.0
+	github.com/containerd/log v0.1.0
 	github.com/containerd/protobuild v0.3.0
 	github.com/containerd/ttrpc v1.2.5
 	github.com/containerd/typeurl/v2 v2.2.0
@@ -52,7 +53,6 @@ require (
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/containerd/continuity v0.4.2 // indirect
 	github.com/containerd/fifo v1.1.0 // indirect
-	github.com/containerd/log v0.1.0 // indirect
 	github.com/containerd/stargz-snapshotter/estargz v0.14.3 // indirect
 	github.com/coreos/go-systemd/v22 v22.5.0 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect
diff --git a/internal/bridgeutils/commonutils/utilities.go b/internal/bridgeutils/commonutils/utilities.go
new file mode 100644
index 0000000000..4409825ad1
--- /dev/null
+++ b/internal/bridgeutils/commonutils/utilities.go
@@ -0,0 +1,83 @@
+package commonutils
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"math"
+	"strconv"
+
+	"github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr"
+	"github.com/sirupsen/logrus"
+)
+
+type ErrorRecord struct {
+	Result       int32 // HResult
+	Message      string
+	StackTrace   string `json:",omitempty"`
+	ModuleName   string
+	FileName     string
+	Line         uint32
+	FunctionName string `json:",omitempty"`
+}
+
+// UnmarshalJSONWithHresult unmarshals the given data into the given interface, and
+// wraps any error returned in an HRESULT error.
+func UnmarshalJSONWithHresult(data []byte, v interface{}) error {
+	if err := json.Unmarshal(data, v); err != nil {
+		return gcserr.WrapHresult(err, gcserr.HrVmcomputeInvalidJSON)
+	}
+	return nil
+}
+
+// DecodeJSONWithHresult decodes the JSON from the given reader into the given
+// interface, and wraps any error returned in an HRESULT error.
+func DecodeJSONWithHresult(r io.Reader, v interface{}) error {
+	if err := json.NewDecoder(r).Decode(v); err != nil {
+		return gcserr.WrapHresult(err, gcserr.HrVmcomputeInvalidJSON)
+	}
+	return nil
+}
+
+func SetErrorForResponseBaseUtil(errForResponse error, moduleName string) (hresult gcserr.Hresult, errorMessage string, newRecord ErrorRecord) {
+	errorMessage = errForResponse.Error()
+	stackString := ""
+	fileName := ""
+	// We use -1 as a sentinel if no line number found (or it cannot be parsed),
+	// but that will ultimately end up as [math.MaxUint32], so set it to that explicitly.
+	// (Still keep using -1 for backwards compatibility ...)
+	lineNumber := uint32(math.MaxUint32)
+	functionName := ""
+	if stack := gcserr.BaseStackTrace(errForResponse); stack != nil {
+		bottomFrame := stack[0]
+		stackString = fmt.Sprintf("%+v", stack)
+		fileName = fmt.Sprintf("%s", bottomFrame)
+		lineNumberStr := fmt.Sprintf("%d", bottomFrame)
+		if n, err := strconv.ParseUint(lineNumberStr, 10, 32); err == nil {
+			lineNumber = uint32(n)
+		} else {
+			logrus.WithFields(logrus.Fields{
+				"line-number":   lineNumberStr,
+				logrus.ErrorKey: err,
+			}).Error("opengcs::bridge::setErrorForResponseBase - failed to parse line number, using -1 instead")
+		}
+		functionName = fmt.Sprintf("%n", bottomFrame)
+	}
+	hresult, err := gcserr.GetHresult(errForResponse)
+	if err != nil {
+		// Default to using the generic failure HRESULT.
+		hresult = gcserr.HrFail
+	}
+
+	newRecord = ErrorRecord{
+		Result:       int32(hresult),
+		Message:      errorMessage,
+		StackTrace:   stackString,
+		ModuleName:   moduleName,
+		FileName:     fileName,
+		Line:         lineNumber,
+		FunctionName: functionName,
+	}
+
+	return hresult, errorMessage, newRecord
+}
diff --git a/internal/guest/gcserr/errors.go b/internal/bridgeutils/gcserr/errors.go
similarity index 100%
rename from internal/guest/gcserr/errors.go
rename to internal/bridgeutils/gcserr/errors.go
diff --git a/internal/fsformatter/formatter_driver.go b/internal/fsformatter/formatter_driver.go
new file mode 100644
index 0000000000..7317e4c779
--- /dev/null
+++ b/internal/fsformatter/formatter_driver.go
@@ -0,0 +1,291 @@
+//go:build windows
+// +build windows
+
+package fsformatter
+
+import (
+	"context"
+	"encoding/binary"
+	"syscall"
+	"unicode/utf16"
+	"unsafe"
+
+	"github.com/pkg/errors"
+	"golang.org/x/sys/windows"
+)
+
+// This file contains all the supporting structures needed to make
+// an ioctl call to RefsFormatter.
+const (
+	ioctlKernelFormatVolumeFormat = 0x40001000
+	// This is used to construct the disk path that refsFormatter
+	// understands. `harddisk%d` here refers to the disk number
+	// associated with the corresponding lun of the attached
+	// scsi device.
+	VirtualDevObjectPathFormat                      = "\\device\\harddisk%d\\partition0"
+	checksumTypeSha256                              = uint16(4)
+	refsChecksumType                                = checksumTypeSha256
+	maxSizeOfKernelFormatVolumeFormatRefsParameters = 16 * 8 // 128 bytes
+	sizeOfWchar                                     = int(unsafe.Sizeof(uint16(0)))
+	kernelFormatVolumeMaxVolumeLabelLength          = uint32(33 * sizeOfWchar)
+	kernelFormatVolumeWin32DriverPath               = "\\\\?\\KernelFSFormatter"
+	// Allocate large enough buffer for output from fsFormatter
+	maxSizeOfOutputBuffer = uint32(512)
+
+	// KERNEL_FORMAT_VOLUME_FORMAT_REFS_PARAMETERS member offsets
+	clusterSizeOffset      = 0
+	checksumTypeOffset     = 4
+	useDataIntegrityOffset = 6
+	majorVersionOffset     = 8
+	minorVersionOffset     = 10
+)
+
+type kernelFormatVolumeFilesystemTypes uint32
+
+const (
+	kernelFormatVolumeFilesystemTypeInvalid = kernelFormatVolumeFilesystemTypes(iota)
+	kernelFormatVolumeFilesystemTypeRefs    = kernelFormatVolumeFilesystemTypes(1)
+	kernelFormatVolumeFilesystemTypeMax     = kernelFormatVolumeFilesystemTypes(2)
+)
+
+// We only want to allow refs formatting
+func (filesystemType kernelFormatVolumeFilesystemTypes) String() string {
+	switch filesystemType {
+	case kernelFormatVolumeFilesystemTypeRefs:
+		return "KERNEL_FORMAT_VOLUME_FILESYSTEM_TYPE_REFS"
+	default:
+		return "Unknown"
+	}
+}
+
+type kernelFormatVolumeFormatInputBufferFlags uint32
+
+const (
+	kernelFormatVolumeFormatInputBufferFlagNone        = kernelFormatVolumeFormatInputBufferFlags(0x00000000)
+	kernelFormatVolumeFormatInputBufferFlagSuperFloppy = kernelFormatVolumeFormatInputBufferFlags(0x00000001)
+)
+
+func (flag kernelFormatVolumeFormatInputBufferFlags) String() string {
+	switch flag {
+	case kernelFormatVolumeFormatInputBufferFlagNone:
+		return "kernelFormatVolumeFormatInputBufferFlagNone"
+	case kernelFormatVolumeFormatInputBufferFlagSuperFloppy:
+		return "kernelFormatVolumeFormatInputBufferFlagSuperFloppy"
+	default:
+		return "Unknown"
+	}
+}
+
+type KernelFormatVolumeFormatRefsParameters struct {
+	ClusterSize          uint32
+	MetadataChecksumType uint16
+	UseDataIntegrity     bool
+	MajorVersion         uint16
+	MinorVersion         uint16
+}
+
+type KernelFormatVolumeFormatFsParameters struct {
+	FileSystemType kernelFormatVolumeFilesystemTypes
+	// Represents a WCHAR character array
+	VolumeLabel [kernelFormatVolumeMaxVolumeLabelLength / uint32(sizeOfWchar)]uint16
+	// Length of volume label in bytes
+	VolumeLabelLength uint16
+	// RefsFormatterParams represents the following union
+	/*
+		union {
+
+			KERNEL_FORMAT_VOLUME_FORMAT_REFS_PARAMETERS RefsParameters;
+
+			//
+			// This structure can't grow in size nor change in alignment. 16 ULONGLONGs
+			// should be more than enough for supporting other filesystems down the
+			// line. This also serves to enforce 8 byte alignment.
+			//
+			Reserved [16]uint64
+		};
+	*/
+	RefsFormatterParams [128]byte
+}
+
+type KernelFormatVolumeFormatInputBuffer struct {
+	Size         uint64
+	FsParameters KernelFormatVolumeFormatFsParameters
+	Flags        kernelFormatVolumeFormatInputBufferFlags
+	Reserved     [4]uint32
+	// Size of DiskPathBuffer in bytes
+	DiskPathLength uint16
+	// DiskPathBuffer holds the disk path. It represents a
+	// variable size WCHAR character array
+	DiskPathBuffer []uint16
+}
+
+type kernelFormatVolumeFormatOutputBufferFlags uint32
+
+const kernelFormatVolumeFormatOutputBufferFlagsNone = kernelFormatVolumeFormatOutputBufferFlags(0x00000000)
+
+func (flag kernelFormatVolumeFormatOutputBufferFlags) String() string {
+	switch flag {
+	case kernelFormatVolumeFormatOutputBufferFlagsNone:
+		return "kernelFormatVolumeFormatOutputBufferFlagsNone"
+	default:
+		return "Unknown"
+	}
+}
+
+type KernelFormarVolumeFormatOutputBuffer struct {
+	Size     uint32
+	Flags    kernelFormatVolumeFormatOutputBufferFlags
+	Reserved [4]uint32
+	// VolumePathLength holds size of VolumePathBuffer
+	// in bytes
+	VolumePathLength uint16
+	// VolumePathBuffer holds the mounted volume path
+	// as returned from refsFormatter. It represents
+	// a variable size WCHAR character array
+	VolumePathBuffer []uint16
+}
+
+// GetVolumePathBufferOffset gets offset to KernelFormarVolumeFormatOutputBuffer{}.VolumePathBuffer
+func GetVolumePathBufferOffset() uint32 {
+	volPathBufferOffset := uint32(unsafe.Sizeof(KernelFormarVolumeFormatOutputBuffer{}.Size) +
+		unsafe.Sizeof(KernelFormarVolumeFormatOutputBuffer{}.Flags) +
+		unsafe.Sizeof(KernelFormarVolumeFormatOutputBuffer{}.Reserved) +
+		unsafe.Sizeof(KernelFormarVolumeFormatOutputBuffer{}.VolumePathLength))
+
+	return volPathBufferOffset
+}
+
+// getInputBufferSize gets the total size needed for input buffer
+func getInputBufferSize(wcharDiskPathLength uint16) uint32 {
+	bufferSize := uint32(unsafe.Sizeof(KernelFormatVolumeFormatInputBuffer{}.Size)+
+		/* This is specifically for the union in KernelFormatVolumeFormatFsParameters */
+		unsafe.Offsetof(KernelFormatVolumeFormatFsParameters{}.RefsFormatterParams)+
+		maxSizeOfKernelFormatVolumeFormatRefsParameters+
+		unsafe.Sizeof(KernelFormatVolumeFormatInputBuffer{}.Flags)+
+		unsafe.Sizeof(KernelFormatVolumeFormatInputBuffer{}.Reserved)+
+		unsafe.Sizeof(KernelFormatVolumeFormatInputBuffer{}.DiskPathLength)) +
+		uint32(wcharDiskPathLength)
+
+	return bufferSize
+}
+
+// getInputBufferDiskPathBufferOffset gets offset to KernelFormatVolumeFormatInputBuffer{}.DiskPathBuffer
+func getInputBufferDiskPathBufferOffset() uint32 {
+	diskPathBufferOffset := uint32(unsafe.Sizeof(KernelFormatVolumeFormatInputBuffer{}.Size) +
+		unsafe.Offsetof(KernelFormatVolumeFormatFsParameters{}.RefsFormatterParams) +
+		maxSizeOfKernelFormatVolumeFormatRefsParameters +
+		unsafe.Sizeof(KernelFormatVolumeFormatInputBuffer{}.Flags) +
+		unsafe.Sizeof(KernelFormatVolumeFormatInputBuffer{}.Reserved) +
+		unsafe.Sizeof(KernelFormatVolumeFormatInputBuffer{}.DiskPathLength))
+
+	return diskPathBufferOffset
+}
+
+// KmFmtCreateFormatOutputBuffer formats an output buffer as expected
+// by the fsFormatter driver
+func KmFmtCreateFormatOutputBuffer() *KernelFormarVolumeFormatOutputBuffer {
+	buf := make([]uint16, maxSizeOfOutputBuffer)
+	outputBuffer := (*KernelFormarVolumeFormatOutputBuffer)(unsafe.Pointer(&buf[0]))
+	outputBuffer.Size = uint32(maxSizeOfOutputBuffer)
+
+	return outputBuffer
+}
+
+func toUTF16(s string) []uint16 {
+	return utf16.Encode([]rune(s))
+}
+
+// KmFmtCreateFormatInputBuffer formats an input buffer as expected
+// by the refsFormatter driver.
+// diskPath represents disk path in VirtualDevObjectPathFormat.
+func KmFmtCreateFormatInputBuffer(diskPath string) *KernelFormatVolumeFormatInputBuffer {
+	refsParametersBuf := make([]byte, unsafe.Sizeof(KernelFormatVolumeFormatRefsParameters{}))
+	refsParameters := (*KernelFormatVolumeFormatRefsParameters)(unsafe.Pointer(&refsParametersBuf[0]))
+
+	utf16DiskPath := toUTF16(diskPath)
+	wcharDiskPathLength := uint16(len(utf16DiskPath) * sizeOfWchar)
+
+	refsParameters.ClusterSize = 0x1000
+	refsParameters.MetadataChecksumType = refsChecksumType
+	refsParameters.UseDataIntegrity = true
+	refsParameters.MajorVersion = uint16(3)
+	refsParameters.MinorVersion = uint16(14)
+
+	bufferSize := getInputBufferSize(wcharDiskPathLength)
+	buf := make([]byte, bufferSize)
+	inputBuffer := (*KernelFormatVolumeFormatInputBuffer)(unsafe.Pointer(&buf[0]))
+
+	inputBuffer.Size = uint64(bufferSize)
+	inputBuffer.Flags = kernelFormatVolumeFormatInputBufferFlagSuperFloppy
+
+	inputBuffer.FsParameters.FileSystemType = kernelFormatVolumeFilesystemTypeRefs
+	inputBuffer.FsParameters.VolumeLabelLength = 0
+	inputBuffer.FsParameters.VolumeLabel = [33]uint16{}
+
+	// Write KERNEL_FORMAT_VOLUME_FORMAT_REFS_PARAMETERS
+	binary.LittleEndian.PutUint32(inputBuffer.FsParameters.RefsFormatterParams[clusterSizeOffset:], refsParameters.ClusterSize)
+	binary.LittleEndian.PutUint16(inputBuffer.FsParameters.RefsFormatterParams[checksumTypeOffset:], refsParameters.MetadataChecksumType)
+	if refsParameters.UseDataIntegrity {
+		inputBuffer.FsParameters.RefsFormatterParams[useDataIntegrityOffset] = 1
+	} else {
+		inputBuffer.FsParameters.RefsFormatterParams[useDataIntegrityOffset] = 0
+	}
+	binary.LittleEndian.PutUint16(inputBuffer.FsParameters.RefsFormatterParams[majorVersionOffset:], refsParameters.MajorVersion)
+	binary.LittleEndian.PutUint16(inputBuffer.FsParameters.RefsFormatterParams[minorVersionOffset:], refsParameters.MinorVersion)
+
+	// Finally write the diskPathLength and diskPathBuffer with the input disk path
+	inputBuffer.DiskPathLength = wcharDiskPathLength
+	// DiskBuffer writing: copy the UTF-16 path encoded above into the
+	// variable-length tail that follows DiskPathLength.
+	ptr := unsafe.Add(unsafe.Pointer(inputBuffer), getInputBufferDiskPathBufferOffset())
+	diskPathBuf := unsafe.Slice((*uint16)(ptr), len(utf16DiskPath))
+	copy(diskPathBuf, utf16DiskPath)
+
+	return inputBuffer
+}
+
+// InvokeFsFormatter makes an ioctl call to the fsFormatter driver and returns
+// a path to the mountedVolume
+func InvokeFsFormatter(ctx context.Context, diskPath string) (string, error) {
+	// Prepare input and output buffers as expected by refsFormatter
+	inputBuffer := KmFmtCreateFormatInputBuffer(diskPath)
+	outputBuffer := KmFmtCreateFormatOutputBuffer()
+
+	utf16DriverPath, _ := windows.UTF16PtrFromString(kernelFormatVolumeWin32DriverPath) // constant path, no NUL byte
+	deviceHandle, err := windows.CreateFile(utf16DriverPath,
+		windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE,
+		0,
+		nil,
+		windows.OPEN_EXISTING,
+		0,
+		0)
+	if err != nil {
+		return "", errors.Wrap(err, "failed to get handle to refsFormatter driver")
+	}
+	defer windows.Close(deviceHandle)
+
+	// Ioctl to fsFormatter driver
+	var bytesReturned uint32
+	if err := windows.DeviceIoControl(
+		deviceHandle,
+		ioctlKernelFormatVolumeFormat,
+		(*byte)(unsafe.Pointer(inputBuffer)),
+		uint32(inputBuffer.Size),
+		(*byte)(unsafe.Pointer(outputBuffer)),
+		outputBuffer.Size,
+		&bytesReturned,
+		nil,
+	); err != nil {
+		return "", errors.Wrap(err, "ioctl to refsFormatter driver failed")
+	}
+
+	// Read the returned volume path from the corresponding offset in outputBuffer.
+	// The length comes from the driver, so bound it before building an unsafe slice.
+	pathOffset := GetVolumePathBufferOffset()
+	pathBytes := uint32(outputBuffer.VolumePathLength)
+	if pathBytes > maxSizeOfOutputBuffer-pathOffset {
+		return "", errors.Errorf("refsFormatter returned invalid volume path length %d", pathBytes)
+	}
+	utf16Data := unsafe.Slice((*uint16)(unsafe.Add(unsafe.Pointer(outputBuffer), pathOffset)), pathBytes/2)
+	return syscall.UTF16ToString(utf16Data), nil
+}
diff --git a/internal/gcs-sidecar/bridge.go b/internal/gcs-sidecar/bridge.go
new file mode 100644
index 0000000000..ae7b4b965a
--- /dev/null
+++ b/internal/gcs-sidecar/bridge.go
@@ -0,0 +1,463 @@
+//go:build windows
+// +build windows
+
+package bridge
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/base64"
+	"encoding/binary"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"io"
+	"sync"
+
+	"github.com/pkg/errors"
+	"github.com/sirupsen/logrus"
+	"go.opencensus.io/trace"
+	"go.opencensus.io/trace/tracestate"
+	"golang.org/x/sys/windows"
+
+	"github.com/Microsoft/go-winio/pkg/guid"
+	"github.com/Microsoft/hcsshim/internal/bridgeutils/commonutils"
+	"github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr"
+	"github.com/Microsoft/hcsshim/internal/gcs/prot"
+	"github.com/Microsoft/hcsshim/internal/log"
+	"github.com/Microsoft/hcsshim/internal/oc"
+	"github.com/Microsoft/hcsshim/pkg/securitypolicy"
+)
+
+type Bridge struct {
+	mu        sync.Mutex
+	pendingMu sync.Mutex
+	pending   map[sequenceID]*prot.ContainerExecuteProcessResponse
+
+	hostState *Host
+	// List of handlers for handling different rpc message requests.
+	rpcHandlerList map[prot.RPCProc]HandlerFunc
+
+	// hcsshim and inbox GCS connections respectively.
+	shimConn     io.ReadWriteCloser
+	inboxGCSConn io.ReadWriteCloser
+
+	// Response channels to forward incoming requests to inbox GCS
+	// and send responses back to hcsshim respectively.
+	sendToGCSCh  chan request
+	sendToShimCh chan bridgeResponse
+}
+
+// sequenceID is used to correlate requests and responses.
+type sequenceID uint64
+
+// messageHeader is the fixed-size header that precedes every bridge message.
+type messageHeader struct {
+	Type prot.MsgType
+	Size uint32
+	ID   sequenceID
+}
+
+type bridgeResponse struct {
+	ctx      context.Context
+	header   messageHeader
+	response []byte
+}
+
+type request struct {
+	// ctx is the context created when the request was received from the bridge.
+	ctx context.Context
+	// header is the wire-format message header that preceded the payload of
+	// this request.
+	header messageHeader
+	// activityID identifies the specific activity this request belongs to.
+	activityID guid.GUID
+	// message is the part of the request that follows the header.
+	message []byte
+}
+
+func NewBridge(shimConn, inboxGCSConn io.ReadWriteCloser, initialEnforcer securitypolicy.SecurityPolicyEnforcer) *Bridge {
+	b := &Bridge{
+		hostState:      NewHost(initialEnforcer),
+		shimConn:       shimConn,
+		inboxGCSConn:   inboxGCSConn,
+		rpcHandlerList: make(map[prot.RPCProc]HandlerFunc),
+		pending:        make(map[sequenceID]*prot.ContainerExecuteProcessResponse),
+		sendToGCSCh:    make(chan request),
+		sendToShimCh:   make(chan bridgeResponse),
+	}
+	return b
+}
+
+// UnknownMessage is the fallback handler for a request whose type has no
+// registered handler: it logs the type and returns a not-implemented HRESULT error.
+func UnknownMessage(r *request) error {
+	log.G(r.ctx).Debugf("bridge: function not supported, header type %v", r.header.Type.String())
+	return gcserr.WrapHresult(errors.Errorf("bridge: function not supported, header type: %v", r.header.Type), gcserr.HrNotImpl)
+}
+
+// HandlerFunc is an adapter to use functions as handlers.
+type HandlerFunc func(*request) error
+
+func (b *Bridge) getRequestHandler(r *request) (HandlerFunc, error) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	var handler HandlerFunc
+	var ok bool
+	messageType := r.header.Type
+	rpcProcID := prot.RPCProc(messageType &^ prot.MsgTypeMask)
+	if handler, ok = b.rpcHandlerList[rpcProcID]; !ok {
+		return nil, UnknownMessage(r)
+	}
+	return handler, nil
+}
+
+// ServeMsg looks up the handler registered for r's message type and runs it.
+func (b *Bridge) ServeMsg(r *request) error {
+	if r == nil {
+		panic("bridge: nil request to handler")
+	}
+
+	var handler HandlerFunc
+	var err error
+	if handler, err = b.getRequestHandler(r); err != nil {
+		return err // already the UnknownMessage error; calling it again would log twice
+	}
+	return handler(r)
+}
+
+// Handle registers handlerFunc for rpcProcID, replacing any existing handler.
+func (b *Bridge) Handle(rpcProcID prot.RPCProc, handlerFunc HandlerFunc) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	if handlerFunc == nil {
+		panic("empty function handler")
+	}
+
+	if _, ok := b.rpcHandlerList[rpcProcID]; ok {
+		logrus.WithFields(logrus.Fields{
+			"message-type": rpcProcID.String(),
+		}).Warn("overwriting bridge handler")
+	}
+
+	b.rpcHandlerList[rpcProcID] = handlerFunc
+}
+
+func (b *Bridge) HandleFunc(rpcProcID prot.RPCProc, handler func(*request) error) {
+	if handler == nil {
+		panic("bridge: nil handler func")
+	}
+
+	b.Handle(rpcProcID, HandlerFunc(handler))
+}
+
+// AssignHandlers creates and assigns appropriate event handlers
+// for the different bridge message types.
+func (b *Bridge) AssignHandlers() {
+	b.HandleFunc(prot.RPCCreate, b.createContainer)
+	b.HandleFunc(prot.RPCStart, b.startContainer)
+	b.HandleFunc(prot.RPCShutdownGraceful, b.shutdownGraceful)
+	b.HandleFunc(prot.RPCShutdownForced, b.shutdownForced)
+	b.HandleFunc(prot.RPCExecuteProcess, b.executeProcess)
+	b.HandleFunc(prot.RPCWaitForProcess, b.waitForProcess)
+	b.HandleFunc(prot.RPCSignalProcess, b.signalProcess)
+	b.HandleFunc(prot.RPCResizeConsole, b.resizeConsole)
+	b.HandleFunc(prot.RPCGetProperties, b.getProperties)
+	b.HandleFunc(prot.RPCModifySettings, b.modifySettings)
+	b.HandleFunc(prot.RPCNegotiateProtocol, b.negotiateProtocol)
+	b.HandleFunc(prot.RPCDumpStacks, b.dumpStacks)
+	b.HandleFunc(prot.RPCDeleteContainerState, b.deleteContainerState)
+	b.HandleFunc(prot.RPCUpdateContainer, b.updateContainer)
+	b.HandleFunc(prot.RPCLifecycleNotification, b.lifecycleNotification)
+}
+
+// readMessage reads one message from r: the header first, then header.Size-prot.HdrSize payload bytes.
+func readMessage(r io.Reader) (messageHeader, []byte, error) {
+	var h [prot.HdrSize]byte
+	_, err := io.ReadFull(r, h[:])
+	if err != nil {
+		return messageHeader{}, nil, err
+	}
+	var header messageHeader
+	buf := bytes.NewReader(h[:])
+	err = binary.Read(buf, binary.LittleEndian, &header)
+	if err != nil {
+		logrus.WithError(err).Errorf("error reading message header")
+		return messageHeader{}, nil, err
+	}
+
+	n := header.Size
+	if n < prot.HdrSize || n > prot.MaxMsgSize {
+		logrus.Errorf("invalid message size %d", n)
+		return messageHeader{}, nil, fmt.Errorf("invalid message size %d", n) // err is nil here, nothing to wrap
+	}
+
+	n -= prot.HdrSize
+	msg := make([]byte, n)
+	_, err = io.ReadFull(r, msg)
+	if err != nil {
+		if errors.Is(err, io.EOF) {
+			err = io.ErrUnexpectedEOF
+		}
+		return messageHeader{}, nil, err
+	}
+
+	return header, msg, nil
+}
+
+func isLocalDisconnectError(err error) bool {
+	return errors.Is(err, windows.WSAECONNABORTED)
+}
+
+// forwardRequestToGcs sends a copy of req on the inbox GCS channel.
+func (b *Bridge) forwardRequestToGcs(req *request) {
+	b.sendToGCSCh <- *req
+}
+
+// sendResponseToShim marshals response to JSON and sends it, with its header, on the hcsshim channel.
+func (b *Bridge) sendResponseToShim(ctx context.Context, rpcProcType prot.RPCProc, id sequenceID, response interface{}) error {
+	respType := prot.MsgTypeResponse | prot.MsgType(rpcProcType)
+	msgb, err := json.Marshal(response)
+	if err != nil {
+		return err
+	}
+	msgHeader := messageHeader{
+		Type: respType,
+		Size: uint32(len(msgb) + prot.HdrSize),
+		ID:   id,
+	}
+
+	b.sendToShimCh <- bridgeResponse{
+		ctx:      ctx,
+		header:   msgHeader,
+		response: msgb,
+	}
+	return nil
+}
+
+func getContextAndSpan(baseSpanCtx *prot.Ocspancontext) (context.Context, *trace.Span) {
+	var ctx context.Context
+	var span *trace.Span
+	if baseSpanCtx != nil {
+		sc := trace.SpanContext{}
+		if raw, err := hex.DecodeString(baseSpanCtx.TraceID); err == nil {
+			copy(sc.TraceID[:], raw)
+		}
+		if raw, err := hex.DecodeString(baseSpanCtx.SpanID); err == nil {
+			copy(sc.SpanID[:], raw)
+		}
+		sc.TraceOptions = trace.TraceOptions(baseSpanCtx.TraceOptions)
+		if baseSpanCtx.Tracestate != "" {
+			if raw, err := base64.StdEncoding.DecodeString(baseSpanCtx.Tracestate); err == nil {
+				var entries []tracestate.Entry
+				if err := json.Unmarshal(raw, &entries); err == nil {
+					if ts, err := tracestate.New(nil, entries...); err == nil {
+						sc.Tracestate = ts
+					}
+				}
+			}
+		}
+		ctx, span = oc.StartSpanWithRemoteParent(
+			context.Background(),
+			"sidecar::request",
+			sc,
+			oc.WithServerSpanKind,
+		)
+	} else {
+		ctx, span = oc.StartSpan(
+			context.Background(),
+			"sidecar::request",
+			oc.WithServerSpanKind,
+		)
+	}
+
+	return ctx, span
+}
+
+// ListenAndServeShimRequests listens to messages on the hcsshim
+// and inbox GCS connections and schedules them for processing.
+// After processing, messages are forwarded to inbox GCS on success
+// and responses from inbox GCS or error messages are sent back
+// to hcsshim via bridge connection.
+func (b *Bridge) ListenAndServeShimRequests() error { + shimRequestChan := make(chan request) + sidecarErrChan := make(chan error) + + defer b.inboxGCSConn.Close() + defer close(shimRequestChan) + defer close(sidecarErrChan) + defer b.shimConn.Close() + defer close(b.sendToShimCh) + defer close(b.sendToGCSCh) + + // Listen to requests from hcsshim + go func() { + var recverr error + br := bufio.NewReader(b.shimConn) + for { + header, msg, err := readMessage(br) + if err != nil { + if errors.Is(err, io.EOF) || isLocalDisconnectError(err) { + return + } + recverr = errors.Wrap(err, "bridge read from shim connection failed") + logrus.Error(recverr) + break + } + var msgBase prot.RequestBase + _ = json.Unmarshal(msg, &msgBase) + ctx, span := getContextAndSpan(msgBase.OpenCensusSpanContext) + span.AddAttributes( + trace.Int64Attribute("message-id", int64(header.ID)), + trace.StringAttribute("message-type", header.Type.String()), + trace.StringAttribute("activityID", msgBase.ActivityID.String()), + trace.StringAttribute("containerID", msgBase.ContainerID)) + + req := request{ + ctx: ctx, + activityID: msgBase.ActivityID, + header: header, + message: msg, + } + shimRequestChan <- req + } + sidecarErrChan <- recverr + }() + // Process each bridge request received from shim asynchronously. + go func() { + for req := range shimRequestChan { + // Requests are served sequentially to avoid + // racing/reordering of incoming message order. + // This becomes important for confidential cases + // where the shim could be compromised and replay + // messages out of order. + if err := b.ServeMsg(&req); err != nil { + log.G(req.ctx).WithError(err).Errorf("failed to serve request: %v", req.header.Type.String()) + // In case of error, create appropriate response message to + // be sent to hcsshim. 
+ resp := &prot.ResponseBase{ + Result: int32(windows.ERROR_GEN_FAILURE), + ErrorMessage: err.Error(), + ActivityID: req.activityID, + } + setErrorForResponseBase(resp, err, "gcs-sidecar" /* moduleName */) + err = b.sendResponseToShim(req.ctx, prot.RPCProc(prot.MsgTypeResponse), req.header.ID, resp) + log.G(req.ctx).WithError(err).Errorf("failed to send response to shim") + } + } + }() + go func() { + var err error + for req := range b.sendToGCSCh { + // Forward message to gcs + log.G(req.ctx).Tracef("bridge send to gcs, req %v, %v", req.header.Type.String(), string(req.message)) + buffer, err := b.prepareResponseMessage(req.header, req.message) + if err != nil { + err = errors.Wrap(err, "error preparing response") + logrus.Error(err) + break + } + _, err = buffer.WriteTo(b.inboxGCSConn) + if err != nil { + err = errors.Wrap(err, "err forwarding shim req to inbox GCS") + logrus.Error(err) + break + } + } + sidecarErrChan <- err + }() + // Receive response from gcs and forward to hcsshim + go func() { + var recverr error + for { + header, message, err := readMessage(b.inboxGCSConn) + if err != nil { + if errors.Is(err, io.EOF) || isLocalDisconnectError(err) { + return + } + recverr = errors.Wrap(err, "bridge read from gcs failed") + logrus.Error(recverr) + break + } + // If this is a ContainerExecuteProcessResponse, notify + const MsgExecuteProcessResponse prot.MsgType = prot.MsgTypeResponse | prot.MsgType(prot.RPCExecuteProcess) + + if header.Type == MsgExecuteProcessResponse { + logrus.Tracef("Printing after inbox exec resp") + var procResp prot.ContainerExecuteProcessResponse + if err := json.Unmarshal(message, &procResp); err != nil { + logrus.Tracef("unmarshal failed") + } + + b.pendingMu.Lock() + if _, exists := b.pending[header.ID]; exists { + logrus.Tracef("Header ID in pending exists") + b.pending[header.ID] = &procResp + } + b.pendingMu.Unlock() + } + + // Forward to shim + resp := bridgeResponse{ + ctx: context.Background(), + header: header, + response: 
message, + } + b.sendToShimCh <- resp + } + sidecarErrChan <- recverr + }() + // Send response to hcsshim + go func() { + var sendErr error + for resp := range b.sendToShimCh { + // Send response to shim + logrus.Tracef("Send response to shim. Header:{ID: %v, Type: %v, Size: %v} msg: %v", resp.header.ID, + resp.header.Type, resp.header.Size, string(resp.response)) + buffer, err := b.prepareResponseMessage(resp.header, resp.response) + if err != nil { + sendErr = errors.Wrap(err, "error preparing response") + logrus.Error(sendErr) + break + } + _, sendErr = buffer.WriteTo(b.shimConn) + if sendErr != nil { + sendErr = errors.Wrap(sendErr, "err sending response to shim") + logrus.Error(sendErr) + break + } + } + sidecarErrChan <- sendErr + }() + + err := <-sidecarErrChan + return err +} + +// Prepare response message +func (b *Bridge) prepareResponseMessage(header messageHeader, message []byte) (bytes.Buffer, error) { + // Create a buffer to hold the serialized header data + var headerBuf bytes.Buffer + err := binary.Write(&headerBuf, binary.LittleEndian, header) + if err != nil { + return headerBuf, err + } + + // Write message header followed by actual payload. + var buf bytes.Buffer + buf.Write(headerBuf.Bytes()) + buf.Write(message[:]) + return buf, nil +} + +// setErrorForResponseBase modifies the passed-in ResponseBase to +// contain information pertaining to the given error. 
+func setErrorForResponseBase(response *prot.ResponseBase, errForResponse error, moduleName string) { + hresult, errorMessage, newRecord := commonutils.SetErrorForResponseBaseUtil(errForResponse, moduleName) + response.Result = int32(hresult) + response.ErrorMessage = errorMessage + response.ErrorRecords = append(response.ErrorRecords, newRecord) +} diff --git a/internal/gcs-sidecar/handlers.go b/internal/gcs-sidecar/handlers.go new file mode 100644 index 0000000000..d2b3688656 --- /dev/null +++ b/internal/gcs-sidecar/handlers.go @@ -0,0 +1,829 @@ +//go:build windows +// +build windows + +package bridge + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/Microsoft/hcsshim/hcn" + "github.com/Microsoft/hcsshim/internal/bridgeutils/commonutils" + "github.com/Microsoft/hcsshim/internal/fsformatter" + "github.com/Microsoft/hcsshim/internal/gcs/prot" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/oc" + "github.com/Microsoft/hcsshim/internal/oci" + "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/windevice" + "github.com/Microsoft/hcsshim/pkg/annotations" + "github.com/Microsoft/hcsshim/pkg/cimfs" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" + "github.com/pkg/errors" + "golang.org/x/sys/windows" +) + +const ( + sandboxStateDirName = "WcSandboxState" + hivesDirName = "Hives" + devPathFormat = "\\\\.\\PHYSICALDRIVE%d" + UVMContainerID = "00000000-0000-0000-0000-000000000000" +) + +// - Handler functions handle the incoming message requests. It +// also enforces security policy for confidential cwcow containers. +// - These handler functions may do some additional processing before +// forwarding requests to inbox GCS or send responses back to hcsshim. 
+// - In case of any error encountered during processing, appropriate error +// messages are returned and responses are sent back to hcsshim from ListenAndServer(). +// TODO (kiashok): Verbose logging is for WIP and will be removed eventually. +func (b *Bridge) createContainer(req *request) (err error) { + ctx, span := oc.StartSpan(req.ctx, "sidecar::createContainer") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var createContainerRequest prot.ContainerCreate + var containerConfig json.RawMessage + createContainerRequest.ContainerConfig.Value = &containerConfig + if err = commonutils.UnmarshalJSONWithHresult(req.message, &createContainerRequest); err != nil { + return errors.Wrap(err, "failed to unmarshal createContainer") + } + + // containerConfig can be of type uvnConfig or hcsschema.HostedSystem or guestresource.CWCOWHostedSystem + var ( + uvmConfig prot.UvmConfig + hostedSystemConfig hcsschema.HostedSystem + cwcowHostedSystemConfig guestresource.CWCOWHostedSystem + ) + if err = commonutils.UnmarshalJSONWithHresult(containerConfig, &uvmConfig); err == nil && + uvmConfig.SystemType != "" { + systemType := uvmConfig.SystemType + timeZoneInformation := uvmConfig.TimeZoneInformation + log.G(ctx).Tracef("createContainer: uvmConfig: {systemType: %v, timeZoneInformation: %v}}", systemType, timeZoneInformation) + } else if err = commonutils.UnmarshalJSONWithHresult(containerConfig, &hostedSystemConfig); err == nil && + hostedSystemConfig.SchemaVersion != nil && hostedSystemConfig.Container != nil { + schemaVersion := hostedSystemConfig.SchemaVersion + container := hostedSystemConfig.Container + log.G(ctx).Tracef("rpcCreate: HostedSystemConfig: {schemaVersion: %v, container: %v}}", schemaVersion, container) + } else if err = commonutils.UnmarshalJSONWithHresult(containerConfig, &cwcowHostedSystemConfig); err == nil && + cwcowHostedSystemConfig.Spec.Version != "" && cwcowHostedSystemConfig.CWCOWHostedSystem.Container != nil { + 
cwcowHostedSystem := cwcowHostedSystemConfig.CWCOWHostedSystem + schemaVersion := cwcowHostedSystem.SchemaVersion + container := cwcowHostedSystem.Container + spec := cwcowHostedSystemConfig.Spec + containerID := createContainerRequest.ContainerID + log.G(ctx).Tracef("rpcCreate: CWCOWHostedSystemConfig {spec: %v, schemaVersion: %v, container: %v}}", string(req.message), schemaVersion, container) + if b.hostState.isSecurityPolicyEnforcerInitialized() { + user := securitypolicy.IDName{ + Name: spec.Process.User.Username, + } + log.G(ctx).Tracef("user test: %v", user) + _, _, _, err := b.hostState.securityPolicyEnforcer.EnforceCreateContainerPolicyV2(req.ctx, containerID, spec.Process.Args, spec.Process.Env, spec.Process.Cwd, spec.Mounts, user, nil) + + if err != nil { + return fmt.Errorf("CreateContainer operation is denied by policy: %v", err) + } + c := &Container{ + id: containerID, + spec: spec, + processes: make(map[uint32]*containerProcess), + } + log.G(ctx).Tracef("Adding ContainerID: %v", containerID) + if err := b.hostState.AddContainer(req.ctx, containerID, c); err != nil { + log.G(ctx).Tracef("Container exists in the map!") + } + defer func(err error) { + if err != nil { + b.hostState.RemoveContainer(containerID) + } + }(err) + // Write security policy, signed UVM reference and host AMD certificate to + // container's rootfs, so that application and sidecar containers can have + // access to it. The security policy is required by containers which need to + // extract init-time claims found in the security policy. The directory path + // containing the files is exposed via UVM_SECURITY_CONTEXT_DIR env var. + // It may be an error to have a security policy but not expose it to the + // container as in that case it can never be checked as correct by a verifier. 
+ if oci.ParseAnnotationsBool(ctx, spec.Annotations, annotations.UVMSecurityPolicyEnv, true) { + encodedPolicy := b.hostState.securityPolicyEnforcer.EncodedSecurityPolicy() + hostAMDCert := spec.Annotations[annotations.HostAMDCertificate] + if len(encodedPolicy) > 0 || len(hostAMDCert) > 0 || len(b.hostState.uvmReferenceInfo) > 0 { + // Use os.MkdirTemp to make sure that the directory is unique. + securityContextDir, err := os.MkdirTemp(spec.Root.Path, securitypolicy.SecurityContextDirTemplate) + if err != nil { + return fmt.Errorf("failed to create security context directory: %w", err) + } + // Make sure that files inside directory are readable + if err := os.Chmod(securityContextDir, 0755); err != nil { + return fmt.Errorf("failed to chmod security context directory: %w", err) + } + + if len(encodedPolicy) > 0 { + if err := writeFileInDir(securityContextDir, securitypolicy.PolicyFilename, []byte(encodedPolicy), 0777); err != nil { + return fmt.Errorf("failed to write security policy: %w", err) + } + } + if len(b.hostState.uvmReferenceInfo) > 0 { + if err := writeFileInDir(securityContextDir, securitypolicy.ReferenceInfoFilename, []byte(b.hostState.uvmReferenceInfo), 0777); err != nil { + return fmt.Errorf("failed to write UVM reference info: %w", err) + } + } + + if len(hostAMDCert) > 0 { + if err := writeFileInDir(securityContextDir, securitypolicy.HostAMDCertFilename, []byte(hostAMDCert), 0777); err != nil { + return fmt.Errorf("failed to write host AMD certificate: %w", err) + } + } + + containerCtxDir := fmt.Sprintf("/%s", filepath.Base(securityContextDir)) + secCtxEnv := fmt.Sprintf("UVM_SECURITY_CONTEXT_DIR=%s", containerCtxDir) + spec.Process.Env = append(spec.Process.Env, secCtxEnv) + } + } + } + + // Strip the spec field + hostedSystemBytes, err := json.Marshal(cwcowHostedSystem) + + if err != nil { + return fmt.Errorf("failed to marshal hostedSystem: %w", err) + } + + // marshal it again into a JSON-escaped string which inbox GCS expects + 
hostedSystemEscapedBytes, err := json.Marshal(string(hostedSystemBytes)) + if err != nil { + return fmt.Errorf("failed to marshal hostedSystem JSON: %w", err) + } + + // Prepare a fixed struct that takes in raw message + type containerCreateModified struct { + prot.RequestBase + ContainerConfig json.RawMessage + } + createContainerRequestModified := containerCreateModified{ + RequestBase: createContainerRequest.RequestBase, + ContainerConfig: hostedSystemEscapedBytes, + } + + buf, err := json.Marshal(createContainerRequestModified) + log.G(ctx).Tracef("marshaled request buffer: %s", string(buf)) + if err != nil { + return fmt.Errorf("failed to marshal rpcCreatecontainer: %v", err) + } + var newRequest request + newRequest.ctx = req.ctx + newRequest.header = req.header + newRequest.header.Size = uint32(len(buf)) + prot.HdrSize + newRequest.message = buf + req = &newRequest + } else { + return fmt.Errorf("invalid request to createContainer") + } + + b.forwardRequestToGcs(req) + return err +} + +func writeFileInDir(dir string, filename string, data []byte, perm os.FileMode) error { + st, err := os.Stat(dir) + if err != nil { + return err + } + + if !st.IsDir() { + return fmt.Errorf("not a directory %q", dir) + } + + targetFilename := filepath.Join(dir, filename) + return os.WriteFile(targetFilename, data, perm) +} + +// processParamEnvToOCIEnv converts an Environment field from ProcessParameters +// (a map from environment variable to value) into an array of environment +// variable assignments (where each is in the form "=") which +// can be used by an oci.Process. +func processParamEnvToOCIEnv(environment map[string]string) []string { + environmentList := make([]string, 0, len(environment)) + for k, v := range environment { + // TODO: Do we need to escape things like quotation marks in + // environment variable values? 
+ environmentList = append(environmentList, fmt.Sprintf("%s=%s", k, v)) + } + return environmentList +} + +func (b *Bridge) startContainer(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::startContainer") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.RequestBase + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return errors.Wrapf(err, "failed to unmarshal startContainer") + } + + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) shutdownGraceful(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::shutdownGraceful") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.RequestBase + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return errors.Wrap(err, "failed to unmarshal shutdownGraceful") + } + + // TODO (kiashok/Mahati): Since gcs-sidecar can be used for all types of windows + // containers, it is important to check if we want to + // enforce policy or not. + if b.hostState.isSecurityPolicyEnforcerInitialized() { + b.hostState.securityPolicyEnforcer.EnforceShutdownContainerPolicy(req.ctx, r.ContainerID) + if err != nil { + return fmt.Errorf("shutdownGraceful operation not allowed: %v", err) + } + } + + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) shutdownForced(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::shutdownForced") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.RequestBase + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return errors.Wrap(err, "failed to unmarshal shutdownForced") + } + + b.forwardRequestToGcs(req) + return nil +} + +// escapeArgs makes a Windows-style escaped command line from a set of arguments. 
+func escapeArgs(args []string) string { + escapedArgs := make([]string, len(args)) + for i, a := range args { + escapedArgs[i] = windows.EscapeArg(a) + } + return strings.Join(escapedArgs, " ") +} + +func (b *Bridge) executeProcess(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::executeProcess") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.ContainerExecuteProcess + var processParamSettings json.RawMessage + r.Settings.ProcessParameters.Value = &processParamSettings + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return errors.Wrap(err, "failed to unmarshal executeProcess") + } + containerID := r.RequestBase.ContainerID + var processParams hcsschema.ProcessParameters + if err := commonutils.UnmarshalJSONWithHresult(processParamSettings, &processParams); err != nil { + return errors.Wrap(err, "executeProcess: invalid params type for request") + } + + commandLine := []string{processParams.CommandLine} + + if b.hostState.isSecurityPolicyEnforcerInitialized() { + if containerID == UVMContainerID { + log.G(req.ctx).Tracef("Enforcing policy on external exec process") + _, _, err := b.hostState.securityPolicyEnforcer.EnforceExecExternalProcessPolicy( + req.ctx, + commandLine, + processParamEnvToOCIEnv(processParams.Environment), + processParams.WorkingDirectory, + ) + if err != nil { + return errors.Wrapf(err, "exec is denied due to policy") + } + b.forwardRequestToGcs(req) + } else { + // fetch the container command line + c, err := b.hostState.GetCreatedContainer(req.ctx, containerID) + if err != nil { + log.G(req.ctx).Tracef("Container not found during exec: %v", containerID) + return errors.Wrapf(err, "containerID doesn't exist") + } + + // if this is an exec of Container command line, then it's already enforced + // during container creation, hence skip it here + containerCommandLine := escapeArgs(c.spec.Process.Args) + if processParams.CommandLine != containerCommandLine 
{ + opts := &securitypolicy.ExecOptions{ + User: &securitypolicy.IDName{ + Name: processParams.User, + }, + } + log.G(req.ctx).Tracef("Enforcing policy on exec in container") + _, _, _, err = b.hostState.securityPolicyEnforcer. + EnforceExecInContainerPolicyV2( + req.ctx, + containerID, + commandLine, + processParamEnvToOCIEnv(processParams.Environment), + processParams.WorkingDirectory, + opts, + ) + if err != nil { + return errors.Wrapf(err, "exec in container denied due to policy") + } + } + headerID := req.header.ID + + // initiate process ID + b.pendingMu.Lock() + b.pending[headerID] = nil // nil means not yet received + b.pendingMu.Unlock() + + defer func() { + b.pendingMu.Lock() + delete(b.pending, headerID) + b.pendingMu.Unlock() + }() + + // forward the request to gcs + b.forwardRequestToGcs(req) + + // fetch the process ID from response + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + log.G(req.ctx).Tracef("waiting for exec resp") + b.pendingMu.Lock() + resp := b.pending[headerID] + b.pendingMu.Unlock() + + // capture the Process details, so that we can later enforce + // on the allowed signals on the Process + if resp != nil { + log.G(req.ctx).Tracef("Got response: %+v", resp) + c.processesMutex.Lock() + defer c.processesMutex.Unlock() + c.processes[resp.ProcessID] = &containerProcess{ + processspec: processParams, + cid: c.id, + pid: resp.ProcessID, + } + return nil + } + time.Sleep(10 * time.Millisecond) // backoff + } + + return errors.Wrap(err, "timedout waiting for exec response") + } + } else { + b.forwardRequestToGcs(req) + } + return nil +} + +func (b *Bridge) waitForProcess(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::waitForProcess") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.ContainerWaitForProcess + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return errors.Wrap(err, "failed to unmarshal waitForProcess") + 
} + + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) signalProcess(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::signalProcess") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.ContainerSignalProcess + var rawOpts json.RawMessage + r.Options = &rawOpts + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return errors.Wrap(err, "failed to unmarshal signalProcess") + } + var wcowOptions guestresource.SignalProcessOptionsWCOW + if rawOpts != nil { + if err := commonutils.UnmarshalJSONWithHresult(rawOpts, &wcowOptions); err != nil { + return errors.Wrap(err, "signalProcess: invalid Options type for request") + } + + if b.hostState.isSecurityPolicyEnforcerInitialized() { + log.G(req.ctx).Tracef("RawOpts are not nil") + containerID := r.RequestBase.ContainerID + c, err := b.hostState.GetCreatedContainer(req.ctx, containerID) + if err != nil { + return err + } + + p, err := c.GetProcess(r.ProcessID) + if err != nil { + log.G(req.ctx).Tracef("Process not found %v", r.ProcessID) + return err + } + cmdLine := p.processspec.CommandLine + opts := &securitypolicy.SignalContainerOptions{ + IsInitProcess: false, + WindowsSignal: wcowOptions.Signal, + WindowsCommand: cmdLine, + } + err = b.hostState.securityPolicyEnforcer.EnforceSignalContainerProcessPolicyV2(req.ctx, containerID, opts) + if err != nil { + return err + } + } + + } + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) resizeConsole(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::resizeConsole") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.ContainerResizeConsole + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return fmt.Errorf("failed to unmarshal resizeConsole: %v", req) + } + + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) getProperties(req *request) (err error) { + _, span := 
oc.StartSpan(req.ctx, "sidecar::getProperties") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + if b.hostState.isSecurityPolicyEnforcerInitialized() { + err := b.hostState.securityPolicyEnforcer.EnforceGetPropertiesPolicy(req.ctx) + if err != nil { + return errors.Wrapf(err, "get properties denied due to policy") + } + } + + var getPropReqV2 prot.ContainerGetPropertiesV2 + if err := commonutils.UnmarshalJSONWithHresult(req.message, &getPropReqV2); err != nil { + return errors.Wrapf(err, "failed to unmarshal getProperties: %v", string(req.message)) + } + log.G(req.ctx).Tracef("getProperties query: %v", getPropReqV2.Query.PropertyTypes) + + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) negotiateProtocol(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::negotiateProtocol") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.NegotiateProtocolRequest + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return errors.Wrap(err, "failed to unmarshal negotiateProtocol") + } + + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) dumpStacks(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::dumpStacks") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.DumpStacksRequest + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return errors.Wrap(err, "failed to unmarshal dumpStacks") + } + + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) deleteContainerState(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::deleteContainerState") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.DeleteContainerStateRequest + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return errors.Wrap(err, "failed to unmarshal deleteContainerState") + } + + //TODO: Remove container state 
locally before passing it to inbox-gcs + /* + c, err := b.hostState.GetCreatedContainer(request.ContainerID) + if err != nil { + return nil, err + } + // remove container state regardless of delete's success + defer b.hostState.RemoveContainer(request.ContainerID)*/ + + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) updateContainer(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::updateContainer") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + // No callers in the code for rpcUpdateContainer + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) lifecycleNotification(req *request) (err error) { + _, span := oc.StartSpan(req.ctx, "sidecar::lifecycleNotification") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + // No callers in the code for rpcLifecycleNotification + b.forwardRequestToGcs(req) + return nil +} + +func (b *Bridge) modifySettings(req *request) (err error) { + ctx, span := oc.StartSpan(req.ctx, "sidecar::modifySettings") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + log.G(ctx).Tracef("modifySettings: MsgType: %v, Payload: %v", req.header.Type, string(req.message)) + modifyRequest, err := unmarshalContainerModifySettings(req) + if err != nil { + return err + } + modifyGuestSettingsRequest := modifyRequest.Request.(*guestrequest.ModificationRequest) + guestResourceType := modifyGuestSettingsRequest.ResourceType + guestRequestType := modifyGuestSettingsRequest.RequestType + log.G(ctx).Tracef("modifySettings: resourceType: %v, requestType: %v", guestResourceType, guestRequestType) + + if guestRequestType == "" { + guestRequestType = guestrequest.RequestTypeAdd + } + + if guestRequestType == "" { + guestRequestType = guestrequest.RequestTypeAdd + } + + switch guestRequestType { + case guestrequest.RequestTypeAdd: + case guestrequest.RequestTypeRemove: + case guestrequest.RequestTypePreAdd: + case guestrequest.RequestTypeUpdate: + 
default: + return fmt.Errorf("invald guestRequestType %v", guestRequestType) + } + + if guestResourceType != "" { + switch guestResourceType { + case guestresource.ResourceTypeCombinedLayers: + settings := modifyGuestSettingsRequest.Settings.(*guestresource.WCOWCombinedLayers) + log.G(ctx).Tracef("WCOWCombinedLayers: {%v}", settings) + + case guestresource.ResourceTypeNetworkNamespace: + settings := modifyGuestSettingsRequest.Settings.(*hcn.HostComputeNamespace) + log.G(ctx).Tracef("HostComputeNamespaces { %v}", settings) + + case guestresource.ResourceTypeNetwork: + settings := modifyGuestSettingsRequest.Settings.(*guestrequest.NetworkModifyRequest) + log.G(ctx).Tracef("NetworkModifyRequest { %v}", settings) + + case guestresource.ResourceTypeMappedVirtualDisk: + wcowMappedVirtualDisk := modifyGuestSettingsRequest.Settings.(*guestresource.WCOWMappedVirtualDisk) + log.G(ctx).Tracef("wcowMappedVirtualDisk { %v}", wcowMappedVirtualDisk) + + case guestresource.ResourceTypeHvSocket: + hvSocketAddress := modifyGuestSettingsRequest.Settings.(*hcsschema.HvSocketAddress) + log.G(ctx).Tracef("hvSocketAddress { %v }", hvSocketAddress) + + case guestresource.ResourceTypeMappedDirectory: + settings := modifyGuestSettingsRequest.Settings.(*hcsschema.MappedDirectory) + log.G(ctx).Tracef("hcsschema.MappedDirectory { %v }", settings) + + case guestresource.ResourceTypeSecurityPolicy: + securityPolicyRequest := modifyGuestSettingsRequest.Settings.(*guestresource.WCOWConfidentialOptions) + log.G(ctx).Tracef("WCOWConfidentialOptions: { %v}", securityPolicyRequest) + err := b.hostState.SetWCOWConfidentialUVMOptions(req.ctx, securityPolicyRequest) + if err != nil { + return errors.Wrap(err, "error creating enforcer") + } + // Send response back to shim + resp := &prot.ResponseBase{ + Result: 0, // 0 means success + ActivityID: req.activityID, + } + err = b.sendResponseToShim(req.ctx, prot.RPCModifySettings, req.header.ID, resp) + if err != nil { + return errors.Wrap(err, "error sending 
response to hcsshim") + } + return nil + case guestresource.ResourceTypePolicyFragment: + r, ok := modifyGuestSettingsRequest.Settings.(*guestresource.LCOWSecurityPolicyFragment) + if !ok { + return errors.New("the request settings are not of type LCOWSecurityPolicyFragment") + } + return b.hostState.InjectFragment(ctx, r) + case guestresource.ResourceTypeWCOWBlockCims: + // This is request to mount the merged cim at given volumeGUID + wcowBlockCimMounts := modifyGuestSettingsRequest.Settings.(*guestresource.WCOWBlockCIMMounts) + containerID := wcowBlockCimMounts.ContainerID + log.G(ctx).Tracef("WCOWBlockCIMMounts { %v}", wcowBlockCimMounts) + + // The block device takes some time to show up. Wait for a few seconds. + time.Sleep(2 * time.Second) + + var layerCIMs []*cimfs.BlockCIM + layerHashes := make([]string, len(wcowBlockCimMounts.BlockCIMs)) + layerDigests := make([][]byte, len(wcowBlockCimMounts.BlockCIMs)) + ctx := req.ctx + for i, blockCimDevice := range wcowBlockCimMounts.BlockCIMs { + // Get the scsi device path for the blockCim lun + devNumber, err := windevice.GetDeviceNumberFromControllerLUN( + ctx, + 0, /* controller is always 0 for wcow */ + uint8(blockCimDevice.Lun)) + if err != nil { + return errors.Wrap(err, "err getting scsiDevPath") + } + physicalDevPath := fmt.Sprintf(devPathFormat, devNumber) + layerCim := cimfs.BlockCIM{ + Type: cimfs.BlockCIMTypeDevice, + BlockPath: physicalDevPath, + CimName: blockCimDevice.CimName, + } + cimRootDigestBytes, err := cimfs.GetVerificationInfo(physicalDevPath) + if err != nil { + return fmt.Errorf("failed to get CIM verification info: %w", err) + } + layerDigests[i] = cimRootDigestBytes + layerHashes[i] = base64.URLEncoding.EncodeToString(cimRootDigestBytes) + layerCIMs = append(layerCIMs, &layerCim) + log.G(ctx).Debugf("block CIM layer digest %s, path: %s\n", layerHashes[i], physicalDevPath) + } + + // skip the merged cim and verify individual layer hashes + hashesToVerify := layerHashes + if len(layerHashes) 
> 1 { + hashesToVerify = layerHashes[1:] + } + + err := b.hostState.securityPolicyEnforcer.EnforceVerifiedCIMsPolicy(req.ctx, containerID, hashesToVerify) + if err != nil { + return errors.Wrap(err, "CIM mount is denied by policy") + } + + if len(layerCIMs) > 1 { + // Get the topmost merge CIM and invoke the MountMergedBlockCIMs + _, err := cimfs.MountMergedBlockCIMs(layerCIMs[0], layerCIMs[1:], wcowBlockCimMounts.MountFlags, wcowBlockCimMounts.VolumeGuid) + if err != nil { + return errors.Wrap(err, "error mounting multilayer block cims") + } + } else { + _, err := cimfs.MountVerifiedBlockCIM(layerCIMs[0], wcowBlockCimMounts.MountFlags, wcowBlockCimMounts.VolumeGuid, layerDigests[0]) + if err != nil { + return errors.Wrap(err, "error mounting merged block cims") + } + } + + // Send response back to shim + resp := &prot.ResponseBase{ + Result: 0, // 0 means success + ActivityID: req.activityID, + } + err = b.sendResponseToShim(req.ctx, prot.RPCModifySettings, req.header.ID, resp) + if err != nil { + return errors.Wrap(err, "error sending response to hcsshim") + } + return nil + + case guestresource.ResourceTypeCWCOWCombinedLayers: + settings := modifyGuestSettingsRequest.Settings.(*guestresource.CWCOWCombinedLayers) + containerID := settings.ContainerID + log.G(ctx).Tracef("CWCOWCombinedLayers:: ContainerID: %v, ContainerRootPath: %v, Layers: %v, ScratchPath: %v", + containerID, settings.CombinedLayers.ContainerRootPath, settings.CombinedLayers.Layers, settings.CombinedLayers.ScratchPath) + + // check that this is not denied by policy + // TODO: modify gcs-sidecar code to pass context across all calls + // TODO: Update modifyCombinedLayers with verified CimFS API + if b.hostState.isSecurityPolicyEnforcerInitialized() { + policy_err := modifyCombinedLayers(ctx, containerID, guestRequestType, settings.CombinedLayers, b.hostState.securityPolicyEnforcer) + if policy_err != nil { + return errors.Wrapf(policy_err, "CimFS layer mount is denied by policy: %v", settings) + } 
+ } + + // TODO: Update modifyCombinedLayers with verified CimFS API + + // The following two folders are expected to be present in the scratch. + // But since we have just formatted the scratch we would need to + // create them manually. + sandboxStateDirectory := filepath.Join(settings.CombinedLayers.ContainerRootPath, sandboxStateDirName) + err = os.Mkdir(sandboxStateDirectory, 0777) + if err != nil { + return errors.Wrap(err, "failed to create sandboxStateDirectory") + } + + hivesDirectory := filepath.Join(settings.CombinedLayers.ContainerRootPath, hivesDirName) + err = os.Mkdir(hivesDirectory, 0777) + if err != nil { + return errors.Wrap(err, "failed to create hivesDirectory") + } + + // Reconstruct WCOWCombinedLayers{} req before forwarding to GCS + // as GCS does not understand ResourceTypeCWCOWCombinedLayers + modifyGuestSettingsRequest.ResourceType = guestresource.ResourceTypeCombinedLayers + modifyGuestSettingsRequest.Settings = settings.CombinedLayers + modifyRequest.Request = modifyGuestSettingsRequest + buf, err := json.Marshal(modifyRequest) + if err != nil { + return errors.Wrap(err, "failed to marshal rpcModifySettings") + } + var newRequest request + newRequest.ctx = req.ctx + newRequest.header = req.header + newRequest.header.Size = uint32(len(buf)) + prot.HdrSize + newRequest.message = buf + req = &newRequest + + case guestresource.ResourceTypeMappedVirtualDiskForContainerScratch: + wcowMappedVirtualDisk := modifyGuestSettingsRequest.Settings.(*guestresource.WCOWMappedVirtualDisk) + log.G(ctx).Tracef("ResourceTypeMappedVirtualDiskForContainerScratch: { %v }", wcowMappedVirtualDisk) + + // 1. TODO (Mahati): Need to enforce policy before calling into fsFormatter + // 2. Call fsFormatter to format the scratch disk. + // This will return the volume path of the mounted scratch. + // Scratch disk should be >= 30 GB for refs formatter to work. + + // fsFormatter understands only virtualDevObjectPathFormat. 
Therefore fetch the + // disk number for the corresponding lun + var devNumber uint32 + // It could take a few seconds for the attached scsi disk + // to show up inside the UVM. Therefore adding retry logic + // with delay here. + for try := 0; try < 5; try++ { + time.Sleep(1 * time.Second) + devNumber, err = windevice.GetDeviceNumberFromControllerLUN(req.ctx, + 0, /* Only one controller allowed in wcow hyperv */ + uint8(wcowMappedVirtualDisk.Lun)) + if err != nil { + if try == 4 { + // bail out + return errors.Wrapf(err, "error getting diskNumber for LUN %d", wcowMappedVirtualDisk.Lun) + } + continue + } else { + log.G(ctx).Tracef("DiskNumber of lun %d is: %d", wcowMappedVirtualDisk.Lun, devNumber) + break + } + } + diskPath := fmt.Sprintf(fsformatter.VirtualDevObjectPathFormat, devNumber) + log.G(ctx).Tracef("diskPath: %v, diskNumber: %v ", diskPath, devNumber) + mountedVolumePath, err := fsformatter.InvokeFsFormatter(req.ctx, diskPath) + if err != nil { + return errors.Wrap(err, "failed to invoke refsFormatter") + } + log.G(ctx).Tracef("mountedVolumePath returned from InvokeFsFormatter: %v", mountedVolumePath) + + // Forward the req as is to inbox gcs and let it retreive the volume. + // While forwarding request to inbox gcs, make sure to replace the + // resourceType to ResourceTypeMappedVirtualDisk that inbox GCS + // understands. 
+ modifyGuestSettingsRequest.ResourceType = guestresource.ResourceTypeMappedVirtualDisk + modifyRequest.Request = modifyGuestSettingsRequest + buf, err := json.Marshal(modifyRequest) + if err != nil { + return errors.Wrap(err, "failed to marshal WCOWMappedVirtualDisk") + } + var newRequest request + newRequest.ctx = req.ctx + newRequest.header = req.header + newRequest.header.Size = uint32(len(buf)) + prot.HdrSize + newRequest.message = buf + req = &newRequest + + default: + // Invalid request + return fmt.Errorf("invald modifySettingsRequest: %v", guestResourceType) + } + } + + b.forwardRequestToGcs(req) + return nil +} diff --git a/internal/gcs-sidecar/host.go b/internal/gcs-sidecar/host.go new file mode 100644 index 0000000000..c1e0e6e8c3 --- /dev/null +++ b/internal/gcs-sidecar/host.go @@ -0,0 +1,246 @@ +//go:build windows +// +build windows + +package bridge + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/Microsoft/cosesign1go/pkg/cosesign1" + didx509resolver "github.com/Microsoft/didx509go/pkg/did-x509-resolver" + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/logfields" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/pspdriver" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" + oci "github.com/opencontainers/runtime-spec/specs-go" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +type Host struct { + containersMutex sync.Mutex + containers map[string]*Container + + // state required for the security policy enforcement + policyMutex sync.Mutex + securityPolicyEnforcer securitypolicy.SecurityPolicyEnforcer + securityPolicyEnforcerSet bool + uvmReferenceInfo string +} + +type Container struct { + id string + spec oci.Spec + processesMutex sync.Mutex 
+ processes map[uint32]*containerProcess +} + +// Process is a struct that defines the lifetime and operations associated with +// an oci.Process. +type containerProcess struct { + processspec hcsschema.ProcessParameters + // cid is the container id that owns this process. + cid string + pid uint32 +} + +func NewHost(initialEnforcer securitypolicy.SecurityPolicyEnforcer) *Host { + return &Host{ + containers: make(map[string]*Container), + securityPolicyEnforcer: initialEnforcer, + securityPolicyEnforcerSet: false, + } +} + +// InjectFragment extends current security policy with additional constraints +// from the incoming fragment. Note that it is base64 encoded over the bridge/ +// +// There are three checking steps: +// 1 - Unpack the cose document and check it was actually signed with the cert +// chain inside its header +// 2 - Check that the issuer field did:x509 identifier is for that cert chain +// (ie fingerprint of a non leaf cert and the subject matches the leaf cert) +// 3 - Check that this issuer/feed match the requirement of the user provided +// security policy (done in the regoby LoadFragment) +func (h *Host) InjectFragment(ctx context.Context, fragment *guestresource.LCOWSecurityPolicyFragment) (err error) { + log.G(ctx).WithField("fragment", fmt.Sprintf("%+v", fragment)).Debug("GCS Host.InjectFragment") + + raw, err := base64.StdEncoding.DecodeString(fragment.Fragment) + if err != nil { + return err + } + blob := []byte(fragment.Fragment) + // keep a copy of the fragment, so we can manually figure out what went wrong + // will be removed eventually. Give it a unique name to avoid any potential + // race conditions. 
+ sha := sha256.New() + sha.Write(blob) + timestamp := time.Now() + fragmentPath := fmt.Sprintf("fragment-%x-%d.blob", sha.Sum(nil), timestamp.UnixMilli()) + _ = os.WriteFile(filepath.Join(os.TempDir(), fragmentPath), blob, 0644) + + unpacked, err := cosesign1.UnpackAndValidateCOSE1CertChain(raw) + if err != nil { + return fmt.Errorf("InjectFragment failed COSE validation: %w", err) + } + + payloadString := string(unpacked.Payload[:]) + issuer := unpacked.Issuer + feed := unpacked.Feed + chainPem := unpacked.ChainPem + + log.G(ctx).WithFields(logrus.Fields{ + "issuer": issuer, // eg the DID:x509:blah.... + "feed": feed, + "cty": unpacked.ContentType, + "chainPem": chainPem, + }).Debugf("unpacked COSE1 cert chain") + + log.G(ctx).WithFields(logrus.Fields{ + "payload": payloadString, + }).Tracef("unpacked COSE1 payload") + + if len(issuer) == 0 || len(feed) == 0 { // must both be present + return fmt.Errorf("either issuer and feed must both be provided in the COSE_Sign1 protected header") + } + + // Resolve returns a did doc that we don't need + // we only care if there was an error or not + _, err = didx509resolver.Resolve(unpacked.ChainPem, issuer, true) + if err != nil { + log.G(ctx).Printf("Badly formed fragment - did resolver failed to match fragment did:x509 from chain with purported issuer %s, feed %s - err %s", issuer, feed, err.Error()) + return err + } + + // now offer the payload fragment to the policy + err = h.securityPolicyEnforcer.LoadFragment(ctx, issuer, feed, payloadString) + if err != nil { + return fmt.Errorf("InjectFragment failed policy load: %w", err) + } + log.G(ctx).Printf("passed fragment into the enforcer.") + + return nil +} + +func (h *Host) isSecurityPolicyEnforcerInitialized() bool { + return h.securityPolicyEnforcer != nil +} + +func (h *Host) SetWCOWConfidentialUVMOptions(ctx context.Context, securityPolicyRequest *guestresource.WCOWConfidentialOptions) error { + h.policyMutex.Lock() + defer h.policyMutex.Unlock() + + if 
h.securityPolicyEnforcerSet { + return errors.New("security policy has already been set") + } + + if securityPolicyRequest.NoSecurityHardware || pspdriver.IsSNPEnabled(ctx) { + log.G(ctx).Tracef("Starting psp driver") + // Start the psp driver + if err := pspdriver.StartPSPDriver(ctx); err != nil { + // Failed to start psp driver, return prematurely + return errors.Wrapf(err, "failed to start PSP driver") + } + } else { + // failed to load PSP driver, error out + // TODO (kiashok): Following log can be cleaned up once the caller stops ignoring failure + // due to "rego" error. + log.G(ctx).Fatal("failed to load PSP driver: no hardware support or annotation specified") + return fmt.Errorf("failed to load PSP driver: no hardware support or annotation specified") + } + + // This limit ensures messages are below the character truncation limit that + // can be imposed by an orchestrator + maxErrorMessageLength := 3 * 1024 + + // Initialize security policy enforcer for a given enforcer type and + // encoded security policy. 
+ p, err := securitypolicy.CreateSecurityPolicyEnforcer( + "rego", + securityPolicyRequest.EncodedSecurityPolicy, + DefaultCRIMounts(), + DefaultCRIPrivilegedMounts(), + maxErrorMessageLength, + "windows", + ) + if err != nil { + return fmt.Errorf("error creating security policy enforcer: %w", err) + } + + if err = p.EnforceRuntimeLoggingPolicy(ctx); err == nil { + // TODO: enable OTL logging + //logrus.SetOutput(h.logWriter) + } else { + // TODO: disable OTL logging + //logrus.SetOutput(io.Discard) + } + + h.securityPolicyEnforcer = p + h.securityPolicyEnforcerSet = true + + return nil +} + +func (h *Host) AddContainer(ctx context.Context, id string, c *Container) error { + h.containersMutex.Lock() + defer h.containersMutex.Unlock() + + if _, ok := h.containers[id]; ok { + log.G(ctx).Tracef("Container exists in the map: %v", ok) + } + log.G(ctx).Tracef("AddContainer: ID: %v", id) + h.containers[id] = c + return nil +} + +func (h *Host) RemoveContainer(id string) { + h.containersMutex.Lock() + defer h.containersMutex.Unlock() + + _, ok := h.containers[id] + if !ok { + return + } + + delete(h.containers, id) +} + +func (h *Host) GetCreatedContainer(ctx context.Context, id string) (*Container, error) { + h.containersMutex.Lock() + defer h.containersMutex.Unlock() + + c, ok := h.containers[id] + if !ok { + return nil, gcserr.NewHresultError(gcserr.HrVmcomputeSystemNotFound) + } + return c, nil +} + +// GetProcess returns the Process with the matching 'pid'. If the 'pid' does +// not exit returns error. 
+func (c *Container) GetProcess(pid uint32) (*containerProcess, error) { + //todo: thread a context to this function call + logrus.WithFields(logrus.Fields{ + logfields.ContainerID: c.id, + logfields.ProcessID: pid, + }).Info("opengcs::Container::GetProcess") + + c.processesMutex.Lock() + defer c.processesMutex.Unlock() + + p, ok := c.processes[pid] + if !ok { + return nil, gcserr.NewHresultError(gcserr.HrErrNotFound) + } + return p, nil +} diff --git a/internal/gcs-sidecar/policy.go b/internal/gcs-sidecar/policy.go new file mode 100644 index 0000000000..13b96ce64d --- /dev/null +++ b/internal/gcs-sidecar/policy.go @@ -0,0 +1,19 @@ +//go:build windows +// +build windows + +package bridge + +import ( + oci "github.com/opencontainers/runtime-spec/specs-go" +) + +// DefaultCRIMounts returns default mounts added to windows spec by containerD. +func DefaultCRIMounts() []oci.Mount { + return []oci.Mount{} +} + +// DefaultCRIPrivilegedMounts returns a slice of mounts which are added to the +// windows container spec when a container runs in a privileged mode. 
+func DefaultCRIPrivilegedMounts() []oci.Mount { + return []oci.Mount{} +} diff --git a/internal/gcs-sidecar/uvm.go b/internal/gcs-sidecar/uvm.go new file mode 100644 index 0000000000..f6cf07ec14 --- /dev/null +++ b/internal/gcs-sidecar/uvm.go @@ -0,0 +1,171 @@ +//go:build windows +// +build windows + +package bridge + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/Microsoft/hcsshim/hcn" + "github.com/Microsoft/hcsshim/internal/bridgeutils/commonutils" + "github.com/Microsoft/hcsshim/internal/gcs/prot" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/oc" + "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" + "github.com/pkg/errors" +) + +func modifyMappedVirtualDisk( + ctx context.Context, + rt guestrequest.RequestType, + mvd *guestresource.WCOWMappedVirtualDisk, + securityPolicy securitypolicy.SecurityPolicyEnforcer, +) (err error) { + switch rt { + case guestrequest.RequestTypeAdd: + // TODO: Modify and update this with verified Cims API + return securityPolicy.EnforceDeviceMountPolicy(ctx, mvd.ContainerPath, "hash") + case guestrequest.RequestTypeRemove: + log.G(ctx).Tracef("enforcing mount_device in mappedvirtualdisk") + // TODO: Modify and update this with verified Cims API + return securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.ContainerPath) + default: + return newInvalidRequestTypeError(rt) + } +} + +func modifyCombinedLayers( + ctx context.Context, + containerID string, + rt guestrequest.RequestType, + cl guestresource.WCOWCombinedLayers, + securityPolicy securitypolicy.SecurityPolicyEnforcer, +) (err error) { + switch rt { + case guestrequest.RequestTypeAdd: + layerPaths := make([]string, len(cl.Layers)) + for i, layer := range cl.Layers { + layerPaths[i] = layer.Path + } + //TODO: Remove this when there is 
verified Cimfs API + return securityPolicy.EnforceOverlayMountPolicy(ctx, containerID, layerPaths, cl.ContainerRootPath) + case guestrequest.RequestTypeRemove: + return securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath) + default: + return newInvalidRequestTypeError(rt) + } +} + +func newInvalidRequestTypeError(rt guestrequest.RequestType) error { + return errors.Errorf("the RequestType %q is not supported", rt) +} + +func unmarshalContainerModifySettings(req *request) (_ *prot.ContainerModifySettings, err error) { + ctx, span := oc.StartSpan(req.ctx, "sidecar::unmarshalContainerModifySettings") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + + var r prot.ContainerModifySettings + var requestRawSettings json.RawMessage + r.Request = &requestRawSettings + if err := commonutils.UnmarshalJSONWithHresult(req.message, &r); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal rpcModifySettings") + } + + var modifyGuestSettingsRequest guestrequest.ModificationRequest + var rawGuestRequest json.RawMessage + modifyGuestSettingsRequest.Settings = &rawGuestRequest + if err := commonutils.UnmarshalJSONWithHresult(requestRawSettings, &modifyGuestSettingsRequest); err != nil { + return nil, errors.Wrap(err, "invalid rpcModifySettings ModificationRequest") + } + + if modifyGuestSettingsRequest.RequestType == "" { + modifyGuestSettingsRequest.RequestType = guestrequest.RequestTypeAdd + } + + if modifyGuestSettingsRequest.ResourceType != "" { + switch modifyGuestSettingsRequest.ResourceType { + case guestresource.ResourceTypeCWCOWCombinedLayers: + settings := &guestresource.CWCOWCombinedLayers{} + if err := commonutils.UnmarshalJSONWithHresult(rawGuestRequest, settings); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeCWCOWCombinedLayers request") + } + modifyGuestSettingsRequest.Settings = settings + + case guestresource.ResourceTypeCombinedLayers: + settings := &guestresource.WCOWCombinedLayers{} + if err := 
commonutils.UnmarshalJSONWithHresult(rawGuestRequest, settings); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeCombinedLayers request") + } + modifyGuestSettingsRequest.Settings = settings + + case guestresource.ResourceTypeNetworkNamespace: + settings := &hcn.HostComputeNamespace{} + if err := commonutils.UnmarshalJSONWithHresult(rawGuestRequest, settings); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeNetworkNamespace request") + } + modifyGuestSettingsRequest.Settings = settings + + case guestresource.ResourceTypeNetwork: + settings := &guestrequest.NetworkModifyRequest{} + if err := commonutils.UnmarshalJSONWithHresult(rawGuestRequest, settings); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeNetwork request") + } + modifyGuestSettingsRequest.Settings = settings + + case guestresource.ResourceTypeMappedVirtualDisk: + wcowMappedVirtualDisk := &guestresource.WCOWMappedVirtualDisk{} + if err := commonutils.UnmarshalJSONWithHresult(rawGuestRequest, wcowMappedVirtualDisk); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeMappedVirtualDisk request") + } + modifyGuestSettingsRequest.Settings = wcowMappedVirtualDisk + + case guestresource.ResourceTypeHvSocket: + hvSocketAddress := &hcsschema.HvSocketAddress{} + if err := commonutils.UnmarshalJSONWithHresult(rawGuestRequest, hvSocketAddress); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeHvSocket request") + } + modifyGuestSettingsRequest.Settings = hvSocketAddress + + case guestresource.ResourceTypeMappedDirectory: + settings := &hcsschema.MappedDirectory{} + if err := commonutils.UnmarshalJSONWithHresult(rawGuestRequest, settings); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeMappedDirectory request") + } + modifyGuestSettingsRequest.Settings = settings + + case guestresource.ResourceTypeSecurityPolicy: + securityPolicyRequest := &guestresource.WCOWConfidentialOptions{} + if err := 
commonutils.UnmarshalJSONWithHresult(rawGuestRequest, securityPolicyRequest); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeSecurityPolicy request") + } + modifyGuestSettingsRequest.Settings = securityPolicyRequest + + case guestresource.ResourceTypeMappedVirtualDiskForContainerScratch: + wcowMappedVirtualDisk := &guestresource.WCOWMappedVirtualDisk{} + if err := commonutils.UnmarshalJSONWithHresult(rawGuestRequest, wcowMappedVirtualDisk); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeMappedVirtualDiskForContainerScratch request") + } + modifyGuestSettingsRequest.Settings = wcowMappedVirtualDisk + + case guestresource.ResourceTypeWCOWBlockCims: + wcowBlockCimMounts := &guestresource.WCOWBlockCIMMounts{} + if err := commonutils.UnmarshalJSONWithHresult(rawGuestRequest, wcowBlockCimMounts); err != nil { + return nil, errors.Wrap(err, "invalid ResourceTypeWCOWBlockCims request") + } + modifyGuestSettingsRequest.Settings = wcowBlockCimMounts + + default: + // Invalid request + log.G(ctx).Errorf("Invald modifySettingsRequest: %v", modifyGuestSettingsRequest.ResourceType) + return nil, fmt.Errorf("invald modifySettingsRequest") + } + } + r.Request = &modifyGuestSettingsRequest + return &r, nil +} diff --git a/internal/gcs/bridge.go b/internal/gcs/bridge.go index 0aa9d54536..65ada4ed45 100644 --- a/internal/gcs/bridge.go +++ b/internal/gcs/bridge.go @@ -19,33 +19,22 @@ import ( "go.opencensus.io/trace" "golang.org/x/sys/windows" + "github.com/Microsoft/hcsshim/internal/gcs/prot" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" ) -const ( - hdrSize = 16 - hdrOffType = 0 - hdrOffSize = 4 - hdrOffID = 8 - - // maxMsgSize is the maximum size of an incoming message. This is not - // enforced by the guest today but some maximum must be set to avoid - // unbounded allocations. 
- maxMsgSize = 0x10000 -) - type requestMessage interface { - Base() *requestBase + Base() *prot.RequestBase } type responseMessage interface { - Base() *responseBase + Base() *prot.ResponseBase } // rpc represents an outstanding rpc request to the guest type rpc struct { - proc rpcProc + proc prot.RPCProc id int64 req requestMessage resp responseMessage @@ -78,7 +67,7 @@ const ( bridgeFailureTimeout = time.Minute * 5 ) -type notifyFunc func(*containerNotification) error +type notifyFunc func(*prot.ContainerNotification) error // newBridge returns a bridge on `conn`. It calls `notify` when a // notification message arrives from the guest. It logs transport errors and @@ -141,7 +130,7 @@ func (brdg *bridge) Wait() error { // AsyncRPC sends an RPC request to the guest but does not wait for a response. // If the message cannot be sent before the context is done, then an error is // returned. -func (brdg *bridge) AsyncRPC(ctx context.Context, proc rpcProc, req requestMessage, resp responseMessage) (*rpc, error) { +func (brdg *bridge) AsyncRPC(ctx context.Context, proc prot.RPCProc, req requestMessage, resp responseMessage) (*rpc, error) { call := &rpc{ ch: make(chan struct{}), proc: proc, @@ -222,7 +211,7 @@ func (call *rpc) Wait() { // If allowCancel is set and the context becomes done, returns an error without // waiting for a response. Avoid this on messages that are not idempotent or // otherwise safe to ignore the response of. 
-func (brdg *bridge) RPC(ctx context.Context, proc rpcProc, req requestMessage, resp responseMessage, allowCancel bool) error { +func (brdg *bridge) RPC(ctx context.Context, proc prot.RPCProc, req requestMessage, resp responseMessage, allowCancel bool) error { call, err := brdg.AsyncRPC(ctx, proc, req, resp) if err != nil { return err @@ -259,26 +248,26 @@ func (brdg *bridge) recvLoopRoutine() { } } -func readMessage(r io.Reader) (int64, msgType, []byte, error) { +func readMessage(r io.Reader) (int64, prot.MsgType, []byte, error) { _, span := oc.StartSpan(context.Background(), "bridge receive read message", oc.WithClientSpanKind) defer span.End() - var h [hdrSize]byte + var h [prot.HdrSize]byte _, err := io.ReadFull(r, h[:]) if err != nil { return 0, 0, nil, err } - typ := msgType(binary.LittleEndian.Uint32(h[hdrOffType:])) - n := binary.LittleEndian.Uint32(h[hdrOffSize:]) - id := int64(binary.LittleEndian.Uint64(h[hdrOffID:])) + typ := prot.MsgType(binary.LittleEndian.Uint32(h[prot.HdrOffType:])) + n := binary.LittleEndian.Uint32(h[prot.HdrOffSize:]) + id := int64(binary.LittleEndian.Uint64(h[prot.HdrOffID:])) span.AddAttributes( trace.StringAttribute("type", typ.String()), trace.Int64Attribute("message-id", id)) - if n < hdrSize || n > maxMsgSize { + if n < prot.HdrSize || n > prot.MaxMsgSize { return 0, 0, nil, fmt.Errorf("invalid message size %d", n) } - n -= hdrSize + n -= prot.HdrSize b := make([]byte, n) _, err = io.ReadFull(r, b) if err != nil { @@ -309,8 +298,8 @@ func (brdg *bridge) recvLoop() error { "type": typ.String(), "message-id": id}).Trace("bridge receive") - switch typ & msgTypeMask { - case msgTypeResponse: + switch typ & prot.MsgTypeMask { + case prot.MsgTypeResponse: // Find the request associated with this response. 
brdg.mu.Lock() call := brdg.rpcs[id] @@ -342,11 +331,11 @@ func (brdg *bridge) recvLoop() error { return err } - case msgTypeNotify: - if typ != notifyContainer|msgTypeNotify { + case prot.MsgTypeNotify: + if typ != prot.NotifyContainer|prot.MsgTypeNotify { return fmt.Errorf("bridge received unknown unknown notification message %s", typ) } - var ntf containerNotification + var ntf prot.ContainerNotification ntf.ResultInfo.Value = &json.RawMessage{} err := json.Unmarshal(b, &ntf) if err != nil { @@ -381,7 +370,7 @@ func (brdg *bridge) sendLoop() { } } -func (brdg *bridge) writeMessage(buf *bytes.Buffer, enc *json.Encoder, typ msgType, id int64, req interface{}) error { +func (brdg *bridge) writeMessage(buf *bytes.Buffer, enc *json.Encoder, typ prot.MsgType, id int64, req interface{}) error { var err error _, span := oc.StartSpan(context.Background(), "bridge send", oc.WithClientSpanKind) defer span.End() @@ -391,24 +380,24 @@ func (brdg *bridge) writeMessage(buf *bytes.Buffer, enc *json.Encoder, typ msgTy trace.Int64Attribute("message-id", id)) // Prepare the buffer with the message. - var h [hdrSize]byte - binary.LittleEndian.PutUint32(h[hdrOffType:], uint32(typ)) - binary.LittleEndian.PutUint64(h[hdrOffID:], uint64(id)) + var h [prot.HdrSize]byte + binary.LittleEndian.PutUint32(h[prot.HdrOffType:], uint32(typ)) + binary.LittleEndian.PutUint64(h[prot.HdrOffID:], uint64(id)) buf.Write(h[:]) err = enc.Encode(req) if err != nil { return fmt.Errorf("bridge encode: %w", err) } // Update the message header with the size. 
- binary.LittleEndian.PutUint32(buf.Bytes()[hdrOffSize:], uint32(buf.Len())) + binary.LittleEndian.PutUint32(buf.Bytes()[prot.HdrOffSize:], uint32(buf.Len())) if brdg.log.Logger.GetLevel() > logrus.DebugLevel { - b := buf.Bytes()[hdrSize:] + b := buf.Bytes()[prot.HdrSize:] switch typ { // container environment vars are in rpCreate for linux; rpcExecuteProcess for windows - case msgType(rpcCreate) | msgTypeRequest: + case prot.MsgType(prot.RPCCreate) | prot.MsgTypeRequest: b, err = log.ScrubBridgeCreate(b) - case msgType(rpcExecuteProcess) | msgTypeRequest: + case prot.MsgType(prot.RPCExecuteProcess) | prot.MsgTypeRequest: b, err = log.ScrubBridgeExecProcess(b) } if err != nil { @@ -441,7 +430,7 @@ func (brdg *bridge) sendRPC(buf *bytes.Buffer, enc *json.Encoder, call *rpc) err brdg.rpcs[id] = call brdg.nextID++ brdg.mu.Unlock() - typ := msgType(call.proc) | msgTypeRequest + typ := prot.MsgType(call.proc) | prot.MsgTypeRequest err := brdg.writeMessage(buf, enc, typ, id, call.req) if err != nil { // Try to reclaim this request and fail it. 
diff --git a/internal/gcs/bridge_test.go b/internal/gcs/bridge_test.go index d6b3265c60..3da35e9d9d 100644 --- a/internal/gcs/bridge_test.go +++ b/internal/gcs/bridge_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/Microsoft/hcsshim/internal/gcs/prot" "github.com/sirupsen/logrus" ) @@ -33,7 +34,7 @@ func pipeConn() (*stitched, *stitched) { return &stitched{r1, w2}, &stitched{r2, w1} } -func sendMessage(t *testing.T, w io.Writer, typ msgType, id int64, msg []byte) { +func sendMessage(t *testing.T, w io.Writer, typ prot.MsgType, id int64, msg []byte) { t.Helper() var h [16]byte binary.LittleEndian.PutUint32(h[:], uint32(typ)) @@ -63,18 +64,18 @@ func reflector(t *testing.T, rw io.ReadWriteCloser, delay time.Duration) { return } time.Sleep(delay) // delay is used to test timeouts (when non-zero) - typ ^= msgTypeResponse ^ msgTypeRequest + typ ^= prot.MsgTypeResponse ^ prot.MsgTypeRequest sendMessage(t, rw, typ, id, msg) } } type testReq struct { - requestBase + prot.RequestBase X, Y int } type testResp struct { - responseBase + prot.ResponseBase X, Y int } @@ -92,7 +93,7 @@ func TestBridgeRPC(t *testing.T) { defer b.Close() req := testReq{X: 5} var resp testResp - err := b.RPC(context.Background(), rpcCreate, &req, &resp, false) + err := b.RPC(context.Background(), prot.RPCCreate, &req, &resp, false) if err != nil { t.Fatal(err) } @@ -107,7 +108,7 @@ func TestBridgeRPCResponseTimeout(t *testing.T) { b.Timeout = time.Millisecond * 100 req := testReq{X: 5} var resp testResp - err := b.RPC(context.Background(), rpcCreate, &req, &resp, false) + err := b.RPC(context.Background(), prot.RPCCreate, &req, &resp, false) if err == nil || !strings.Contains(err.Error(), "bridge closed") { t.Fatalf("expected bridge disconnection, got %s", err) } @@ -121,7 +122,7 @@ func TestBridgeRPCContextDone(t *testing.T) { defer cancel() req := testReq{X: 5} var resp testResp - err := b.RPC(ctx, rpcCreate, &req, &resp, true) + err := b.RPC(ctx, prot.RPCCreate, &req, &resp, 
true) if err != context.DeadlineExceeded { //nolint:errorlint t.Fatalf("expected deadline exceeded, got %s", err) } @@ -135,7 +136,7 @@ func TestBridgeRPCContextDoneNoCancel(t *testing.T) { defer cancel() req := testReq{X: 5} var resp testResp - err := b.RPC(ctx, rpcCreate, &req, &resp, false) + err := b.RPC(ctx, prot.RPCCreate, &req, &resp, false) if err == nil || !strings.Contains(err.Error(), "bridge closed") { t.Fatalf("expected bridge disconnection, got %s", err) } @@ -145,13 +146,13 @@ func TestBridgeRPCBridgeClosed(t *testing.T) { b := startReflectedBridge(t, 0) eerr := errors.New("forcibly terminated") b.kill(eerr) - err := b.RPC(context.Background(), rpcCreate, nil, nil, false) + err := b.RPC(context.Background(), prot.RPCCreate, nil, nil, false) if err != eerr { //nolint:errorlint t.Fatal("unexpected: ", err) } } -func sendJSON(t *testing.T, w io.Writer, typ msgType, id int64, msg interface{}) error { +func sendJSON(t *testing.T, w io.Writer, typ prot.MsgType, id int64, msg interface{}) error { t.Helper() msgb, err := json.Marshal(msg) if err != nil { @@ -161,7 +162,7 @@ func sendJSON(t *testing.T, w io.Writer, typ msgType, id int64, msg interface{}) return nil } -func notifyThroughBridge(t *testing.T, typ msgType, msg interface{}, fn notifyFunc) error { +func notifyThroughBridge(t *testing.T, typ prot.MsgType, msg interface{}, fn notifyFunc) error { t.Helper() s, c := pipeConn() b := newBridge(s, fn, logrus.NewEntry(logrus.StandardLogger())) @@ -176,9 +177,9 @@ func notifyThroughBridge(t *testing.T, typ msgType, msg interface{}, fn notifyFu } func TestBridgeNotify(t *testing.T) { - ntf := &containerNotification{Operation: "testing"} + ntf := &prot.ContainerNotification{Operation: "testing"} recvd := false - err := notifyThroughBridge(t, msgTypeNotify|notifyContainer, ntf, func(nntf *containerNotification) error { + err := notifyThroughBridge(t, prot.MsgTypeNotify|prot.NotifyContainer, ntf, func(nntf *prot.ContainerNotification) error { if 
!reflect.DeepEqual(ntf, nntf) { t.Errorf("%+v != %+v", ntf, nntf) } @@ -194,9 +195,9 @@ func TestBridgeNotify(t *testing.T) { } func TestBridgeNotifyFailure(t *testing.T) { - ntf := &containerNotification{Operation: "testing"} + ntf := &prot.ContainerNotification{Operation: "testing"} errMsg := "notify should have failed" - err := notifyThroughBridge(t, msgTypeNotify|notifyContainer, ntf, func(nntf *containerNotification) error { + err := notifyThroughBridge(t, prot.MsgTypeNotify|prot.NotifyContainer, ntf, func(nntf *prot.ContainerNotification) error { return errors.New(errMsg) }) if err == nil || !strings.Contains(err.Error(), errMsg) { diff --git a/internal/gcs/container.go b/internal/gcs/container.go index a64408b834..549abd35a2 100644 --- a/internal/gcs/container.go +++ b/internal/gcs/container.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Microsoft/hcsshim/internal/cow" + "github.com/Microsoft/hcsshim/internal/gcs/prot" "github.com/Microsoft/hcsshim/internal/hcs/schema1" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/log" @@ -53,12 +54,12 @@ func (gc *GuestConnection) CreateContainer(ctx context.Context, cid string, conf if err != nil { return nil, err } - req := containerCreate{ - requestBase: makeRequest(ctx, cid), - ContainerConfig: anyInString{config}, + req := prot.ContainerCreate{ + RequestBase: makeRequest(ctx, cid), + ContainerConfig: prot.AnyInString{Value: config}, } - var resp containerCreateResponse - err = gc.brdg.RPC(ctx, rpcCreate, &req, &resp, false) + var resp prot.ContainerCreateResponse + err = gc.brdg.RPC(ctx, prot.RPCCreate, &req, &resp, false) if err != nil { return nil, err } @@ -129,27 +130,27 @@ func (c *Container) Modify(ctx context.Context, config interface{}) (err error) defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes(trace.StringAttribute("cid", c.id)) - req := containerModifySettings{ - requestBase: makeRequest(ctx, c.id), + req := 
prot.ContainerModifySettings{ + RequestBase: makeRequest(ctx, c.id), Request: config, } - var resp responseBase - return c.gc.brdg.RPC(ctx, rpcModifySettings, &req, &resp, false) + var resp prot.ResponseBase + return c.gc.brdg.RPC(ctx, prot.RPCModifySettings, &req, &resp, false) } -// Properties returns the requested container properties targeting a V1 schema container. +// Properties returns the requested container properties targeting a V1 schema prot.Container. func (c *Container) Properties(ctx context.Context, types ...schema1.PropertyType) (_ *schema1.ContainerProperties, err error) { ctx, span := oc.StartSpan(ctx, "gcs::Container::Properties", oc.WithClientSpanKind) defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes(trace.StringAttribute("cid", c.id)) - req := containerGetProperties{ - requestBase: makeRequest(ctx, c.id), - Query: containerPropertiesQuery{PropertyTypes: types}, + req := prot.ContainerGetProperties{ + RequestBase: makeRequest(ctx, c.id), + Query: prot.ContainerPropertiesQuery{PropertyTypes: types}, } - var resp containerGetPropertiesResponse - err = c.gc.brdg.RPC(ctx, rpcGetProperties, &req, &resp, true) + var resp prot.ContainerGetPropertiesResponse + err = c.gc.brdg.RPC(ctx, prot.RPCGetProperties, &req, &resp, true) if err != nil { return nil, err } @@ -163,12 +164,12 @@ func (c *Container) PropertiesV2(ctx context.Context, types ...hcsschema.Propert defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes(trace.StringAttribute("cid", c.id)) - req := containerGetPropertiesV2{ - requestBase: makeRequest(ctx, c.id), - Query: containerPropertiesQueryV2{PropertyTypes: types}, + req := prot.ContainerGetPropertiesV2{ + RequestBase: makeRequest(ctx, c.id), + Query: prot.ContainerPropertiesQueryV2{PropertyTypes: types}, } - var resp containerGetPropertiesResponseV2 - err = c.gc.brdg.RPC(ctx, rpcGetProperties, &req, &resp, true) + var resp prot.ContainerGetPropertiesResponseV2 + err = c.gc.brdg.RPC(ctx, 
prot.RPCGetProperties, &req, &resp, true) if err != nil { return nil, err } @@ -183,13 +184,13 @@ func (c *Container) Start(ctx context.Context) (err error) { span.AddAttributes(trace.StringAttribute("cid", c.id)) req := makeRequest(ctx, c.id) - var resp responseBase - return c.gc.brdg.RPC(ctx, rpcStart, &req, &resp, false) + var resp prot.ResponseBase + return c.gc.brdg.RPC(ctx, prot.RPCStart, &req, &resp, false) } -func (c *Container) shutdown(ctx context.Context, proc rpcProc) error { +func (c *Container) shutdown(ctx context.Context, proc prot.RPCProc) error { req := makeRequest(ctx, c.id) - var resp responseBase + var resp prot.ResponseBase err := c.gc.brdg.RPC(ctx, proc, &req, &resp, true) if err != nil { if uint32(resp.Result) != hrComputeSystemDoesNotExist { @@ -215,7 +216,7 @@ func (c *Container) Shutdown(ctx context.Context) (err error) { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - return c.shutdown(ctx, rpcShutdownGraceful) + return c.shutdown(ctx, prot.RPCShutdownGraceful) } // Terminate sends a forceful terminate request to the container. 
The container @@ -229,7 +230,7 @@ func (c *Container) Terminate(ctx context.Context) (err error) { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - return c.shutdown(ctx, rpcShutdownForced) + return c.shutdown(ctx, prot.RPCShutdownForced) } func (c *Container) WaitChannel() <-chan struct{} { diff --git a/internal/gcs/guestconnection.go b/internal/gcs/guestconnection.go index fe974b5c17..9107dd4d3b 100644 --- a/internal/gcs/guestconnection.go +++ b/internal/gcs/guestconnection.go @@ -16,6 +16,7 @@ import ( "github.com/Microsoft/go-winio" "github.com/Microsoft/go-winio/pkg/guid" "github.com/Microsoft/hcsshim/internal/cow" + "github.com/Microsoft/hcsshim/internal/gcs/prot" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields" @@ -28,7 +29,7 @@ import ( const ( protocolVersion = 4 - firstIoChannelVsockPort = LinuxGcsVsockPort + 1 + firstIoChannelVsockPort = prot.LinuxGcsVsockPort + 1 nullContainerID = "00000000-0000-0000-0000-000000000000" ) @@ -117,12 +118,12 @@ func (gc *GuestConnection) Protocol() uint32 { // isColdStart should be true when the UVM is being connected to for the first time post-boot. // It should be false for subsequent connections (e.g. if reconnecting to an existing UVM). 
func (gc *GuestConnection) connect(ctx context.Context, isColdStart bool, initGuestState *InitialGuestState) (err error) { - req := negotiateProtocolRequest{ + req := prot.NegotiateProtocolRequest{ MinimumVersion: protocolVersion, MaximumVersion: protocolVersion, } - var resp negotiateProtocolResponse - err = gc.brdg.RPC(ctx, rpcNegotiateProtocol, &req, &resp, true) + var resp prot.NegotiateProtocolResponse + err = gc.brdg.RPC(ctx, prot.RPCNegotiateProtocol, &req, &resp, true) if err != nil { return err } @@ -141,25 +142,25 @@ func (gc *GuestConnection) connect(ctx context.Context, isColdStart bool, initGu } if isColdStart && resp.Capabilities.SendHostCreateMessage { - conf := &uvmConfig{ + conf := &prot.UvmConfig{ SystemType: "Container", } if initGuestState != nil && initGuestState.Timezone != nil { conf.TimeZoneInformation = initGuestState.Timezone } - createReq := containerCreate{ - requestBase: makeRequest(ctx, nullContainerID), - ContainerConfig: anyInString{conf}, + createReq := prot.ContainerCreate{ + RequestBase: makeRequest(ctx, nullContainerID), + ContainerConfig: prot.AnyInString{Value: conf}, } - var createResp responseBase - err = gc.brdg.RPC(ctx, rpcCreate, &createReq, &createResp, true) + var createResp prot.ResponseBase + err = gc.brdg.RPC(ctx, prot.RPCCreate, &createReq, &createResp, true) if err != nil { return err } if resp.Capabilities.SendHostStartMessage { startReq := makeRequest(ctx, nullContainerID) - var startResp responseBase - err = gc.brdg.RPC(ctx, rpcStart, &startReq, &startResp, true) + var startResp prot.ResponseBase + err = gc.brdg.RPC(ctx, prot.RPCStart, &startReq, &startResp, true) if err != nil { return err } @@ -175,12 +176,12 @@ func (gc *GuestConnection) Modify(ctx context.Context, settings interface{}) (er defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - req := containerModifySettings{ - requestBase: makeRequest(ctx, nullContainerID), + req := prot.ContainerModifySettings{ + RequestBase: makeRequest(ctx, 
nullContainerID), Request: settings, } - var resp responseBase - return gc.brdg.RPC(ctx, rpcModifySettings, &req, &resp, false) + var resp prot.ResponseBase + return gc.brdg.RPC(ctx, prot.RPCModifySettings, &req, &resp, false) } func (gc *GuestConnection) DumpStacks(ctx context.Context) (response string, err error) { @@ -188,11 +189,11 @@ func (gc *GuestConnection) DumpStacks(ctx context.Context) (response string, err defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - req := dumpStacksRequest{ - requestBase: makeRequest(ctx, nullContainerID), + req := prot.DumpStacksRequest{ + RequestBase: makeRequest(ctx, nullContainerID), } - var resp dumpStacksResponse - err = gc.brdg.RPC(ctx, rpcDumpStacks, &req, &resp, false) + var resp prot.DumpStacksResponse + err = gc.brdg.RPC(ctx, prot.RPCDumpStacks, &req, &resp, false) return resp.GuestStacks, err } @@ -202,11 +203,11 @@ func (gc *GuestConnection) DeleteContainerState(ctx context.Context, cid string) defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes(trace.StringAttribute("cid", cid)) - req := deleteContainerStateRequest{ - requestBase: makeRequest(ctx, cid), + req := prot.DeleteContainerStateRequest{ + RequestBase: makeRequest(ctx, cid), } - var resp responseBase - return gc.brdg.RPC(ctx, rpcDeleteContainerState, &req, &resp, false) + var resp prot.ResponseBase + return gc.brdg.RPC(ctx, prot.RPCDeleteContainerState, &req, &resp, false) } // Close terminates the guest connection. 
It is undefined to call any other @@ -263,7 +264,7 @@ func (gc *GuestConnection) requestNotify(cid string, ch chan struct{}) error { return nil } -func (gc *GuestConnection) notify(ntf *containerNotification) error { +func (gc *GuestConnection) notify(ntf *prot.ContainerNotification) error { cid := ntf.ContainerID gc.mu.Lock() ch := gc.notifyChs[cid] @@ -287,14 +288,14 @@ func (gc *GuestConnection) clearNotifies() { } } -func makeRequest(ctx context.Context, cid string) requestBase { - r := requestBase{ +func makeRequest(ctx context.Context, cid string) prot.RequestBase { + r := prot.RequestBase{ ContainerID: cid, } span := trace.FromContext(ctx) if span != nil { sc := span.SpanContext() - r.OpenCensusSpanContext = &ocspancontext{ + r.OpenCensusSpanContext = &prot.Ocspancontext{ TraceID: hex.EncodeToString(sc.TraceID[:]), SpanID: hex.EncodeToString(sc.SpanID[:]), TraceOptions: uint32(sc.TraceOptions), diff --git a/internal/gcs/guestconnection_test.go b/internal/gcs/guestconnection_test.go index facb0dd34b..6a72cb8a3f 100644 --- a/internal/gcs/guestconnection_test.go +++ b/internal/gcs/guestconnection_test.go @@ -21,6 +21,7 @@ import ( "go.opencensus.io/trace" "go.opencensus.io/trace/tracestate" + "github.com/Microsoft/hcsshim/internal/gcs/prot" "github.com/Microsoft/hcsshim/internal/oc" ) @@ -55,24 +56,24 @@ func simpleGcsLoop(t *testing.T, rw io.ReadWriter) error { } return err } - switch proc := rpcProc(typ &^ msgTypeRequest); proc { - case rpcNegotiateProtocol: - err := sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &negotiateProtocolResponse{ + switch proc := prot.RPCProc(typ &^ prot.MsgTypeRequest); proc { + case prot.RPCNegotiateProtocol: + err := sendJSON(t, rw, prot.MsgTypeResponse|prot.MsgType(proc), id, &prot.NegotiateProtocolResponse{ Version: protocolVersion, - Capabilities: gcsCapabilities{ + Capabilities: prot.GcsCapabilities{ RuntimeOsType: "linux", }, }) if err != nil { return err } - case rpcCreate: - err := sendJSON(t, rw, 
msgTypeResponse|msgType(proc), id, &containerCreateResponse{}) + case prot.RPCCreate: + err := sendJSON(t, rw, prot.MsgTypeResponse|prot.MsgType(proc), id, &prot.ContainerCreateResponse{}) if err != nil { return err } - case rpcExecuteProcess: - var req containerExecuteProcess + case prot.RPCExecuteProcess: + var req prot.ContainerExecuteProcess var params baseProcessParams req.Settings.ProcessParameters.Value = &params err := json.Unmarshal(b, &req) @@ -111,27 +112,27 @@ func simpleGcsLoop(t *testing.T, rw io.ReadWriter) error { stdout.Close() }() } - err = sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &containerExecuteProcessResponse{ + err = sendJSON(t, rw, prot.MsgTypeResponse|prot.MsgType(proc), id, &prot.ContainerExecuteProcessResponse{ ProcessID: 42, }) if err != nil { return err } - case rpcWaitForProcess: + case prot.RPCWaitForProcess: // nothing - case rpcShutdownForced: - var req requestBase + case prot.RPCShutdownForced: + var req prot.RequestBase err = json.Unmarshal(b, &req) if err != nil { return err } - err = sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &responseBase{}) + err = sendJSON(t, rw, prot.MsgTypeResponse|prot.MsgType(proc), id, &prot.ResponseBase{}) if err != nil { return err } time.Sleep(50 * time.Millisecond) - err = sendJSON(t, rw, msgType(msgTypeNotify|notifyContainer), 0, &containerNotification{ - requestBase: requestBase{ + err = sendJSON(t, rw, prot.MsgType(prot.MsgTypeNotify|prot.NotifyContainer), 0, &prot.ContainerNotification{ + RequestBase: prot.RequestBase{ ContainerID: req.ContainerID, }, }) diff --git a/internal/gcs/process.go b/internal/gcs/process.go index 87c5c29ae4..91d3a87faa 100644 --- a/internal/gcs/process.go +++ b/internal/gcs/process.go @@ -12,6 +12,7 @@ import ( "github.com/Microsoft/go-winio" "github.com/Microsoft/hcsshim/internal/cow" + "github.com/Microsoft/hcsshim/internal/gcs/prot" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields"
"github.com/Microsoft/hcsshim/internal/oc" @@ -29,7 +30,7 @@ type Process struct { cid string id uint32 waitCall *rpc - waitResp containerWaitForProcessResponse + waitResp prot.ContainerWaitForProcessResponse stdin, stdout, stderr *ioChannel stdinCloseWriteOnce sync.Once stdinCloseWriteErr error @@ -52,10 +53,10 @@ func (gc *GuestConnection) exec(ctx context.Context, cid string, params interfac return nil, err } - req := containerExecuteProcess{ - requestBase: makeRequest(ctx, cid), - Settings: executeProcessSettings{ - ProcessParameters: anyInString{params}, + req := prot.ContainerExecuteProcess{ + RequestBase: makeRequest(ctx, cid), + Settings: prot.ExecuteProcessSettings{ + ProcessParameters: prot.AnyInString{Value: params}, }, } @@ -68,8 +69,8 @@ func (gc *GuestConnection) exec(ctx context.Context, cid string, params interfac // Construct the stdio channels. Windows guests expect hvsock service IDs // instead of vsock ports. - var hvsockSettings executeProcessStdioRelaySettings - var vsockSettings executeProcessVsockStdioRelaySettings + var hvsockSettings prot.ExecuteProcessStdioRelaySettings + var vsockSettings prot.ExecuteProcessVsockStdioRelaySettings if gc.os == "windows" { req.Settings.StdioRelaySettings = &hvsockSettings } else { @@ -100,20 +101,20 @@ func (gc *GuestConnection) exec(ctx context.Context, cid string, params interfac hvsockSettings.StdErr = &g } - var resp containerExecuteProcessResponse - err = gc.brdg.RPC(ctx, rpcExecuteProcess, &req, &resp, false) + var resp prot.ContainerExecuteProcessResponse + err = gc.brdg.RPC(ctx, prot.RPCExecuteProcess, &req, &resp, false) if err != nil { return nil, err } p.id = resp.ProcessID log.G(ctx).WithField("pid", p.id).Debug("created process pid") // Start a wait message. 
- waitReq := containerWaitForProcess{ - requestBase: makeRequest(ctx, cid), + waitReq := prot.ContainerWaitForProcess{ + RequestBase: makeRequest(ctx, cid), ProcessID: p.id, TimeoutInMs: 0xffffffff, } - p.waitCall, err = gc.brdg.AsyncRPC(ctx, rpcWaitForProcess, &waitReq, &p.waitResp) + p.waitCall, err = gc.brdg.AsyncRPC(ctx, prot.RPCWaitForProcess, &waitReq, &p.waitResp) if err != nil { return nil, fmt.Errorf("failed to wait on process, leaking process: %w", err) } @@ -220,14 +221,14 @@ func (p *Process) ResizeConsole(ctx context.Context, width, height uint16) (err trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - req := containerResizeConsole{ - requestBase: makeRequest(ctx, p.cid), + req := prot.ContainerResizeConsole{ + RequestBase: makeRequest(ctx, p.cid), ProcessID: p.id, Height: height, Width: width, } - var resp responseBase - return p.gc.brdg.RPC(ctx, rpcResizeConsole, &req, &resp, true) + var resp prot.ResponseBase + return p.gc.brdg.RPC(ctx, prot.RPCResizeConsole, &req, &resp, true) } // Signal sends a signal to the process, returning whether it was delivered. @@ -239,15 +240,15 @@ func (p *Process) Signal(ctx context.Context, options interface{}) (_ bool, err trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - req := containerSignalProcess{ - requestBase: makeRequest(ctx, p.cid), + req := prot.ContainerSignalProcess{ + RequestBase: makeRequest(ctx, p.cid), ProcessID: p.id, Options: options, } - var resp responseBase + var resp prot.ResponseBase // FUTURE: SIGKILL is idempotent and can safely be cancelled, but this interface // does currently make it easy to determine what signal is being sent. 
- err = p.gc.brdg.RPC(ctx, rpcSignalProcess, &req, &resp, false) + err = p.gc.brdg.RPC(ctx, prot.RPCSignalProcess, &req, &resp, false) if err != nil { if uint32(resp.Result) != hrNotFound { return false, err diff --git a/internal/gcs/prot/protocol.go b/internal/gcs/prot/protocol.go new file mode 100644 index 0000000000..6be8f95482 --- /dev/null +++ b/internal/gcs/prot/protocol.go @@ -0,0 +1,399 @@ +//go:build windows + +package prot + +import ( + "encoding/json" + "fmt" + "strconv" + + "github.com/Microsoft/go-winio/pkg/guid" + "github.com/Microsoft/hcsshim/internal/bridgeutils/commonutils" + "github.com/Microsoft/hcsshim/internal/hcs/schema1" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" +) + +const ( + HdrSize = 16 + HdrOffType = 0 + HdrOffSize = 4 + HdrOffID = 8 + + // MaxMsgSize is the maximum size of an incoming message. This is not + // enforced by the guest today but some maximum must be set to avoid + // unbounded allocations. + MaxMsgSize = 0x10000 + + // LinuxGcsVsockPort is the vsock port number that the Linux GCS will + // connect to. + LinuxGcsVsockPort = 0x40000000 +) + +// e0e16197-dd56-4a10-9195-5ee7a155a838 +var HvGUIDLoopback = guid.GUID{ + Data1: 0xe0e16197, + Data2: 0xdd56, + Data3: 0x4a10, + Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38}, +} + +// a42e7cda-d03f-480c-9cc2-a4de20abb878 +var HvGUIDParent = guid.GUID{ + Data1: 0xa42e7cda, + Data2: 0xd03f, + Data3: 0x480c, + Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78}, +} + +// WindowsGcsHvsockServiceID is the hvsock service ID that the Windows GCS +// will connect to. +var WindowsGcsHvsockServiceID = guid.GUID{ + Data1: 0xacef5661, + Data2: 0x84a1, + Data3: 0x4e44, + Data4: [8]uint8{0x85, 0x6b, 0x62, 0x45, 0xe6, 0x9f, 0x46, 0x20}, +} + +// WindowsSidecarGcsHvsockServiceID is the hvsock service ID that the Windows GCS +// sidecar will connect to. This is only used in the confidential mode.
+var WindowsSidecarGcsHvsockServiceID = guid.GUID{ + Data1: 0xae8da506, + Data2: 0xa019, + Data3: 0x4553, + Data4: [8]uint8{0xa5, 0x2b, 0x90, 0x2b, 0xc0, 0xfa, 0x04, 0x11}, +} + +// WindowsGcsHvHostID is the hvsock address for the parent of the VM running the GCS +var WindowsGcsHvHostID = guid.GUID{ + Data1: 0x894cc2d6, + Data2: 0x9d79, + Data3: 0x424f, + Data4: [8]uint8{0x93, 0xfe, 0x42, 0x96, 0x9a, 0xe6, 0xd8, 0xd1}, +} + +type AnyInString struct { + Value interface{} +} + +func (a *AnyInString) MarshalText() ([]byte, error) { + return json.Marshal(a.Value) +} + +func (a *AnyInString) UnmarshalText(b []byte) error { + return json.Unmarshal(b, &a.Value) +} + +type RPCProc uint32 + +const ( + RPCCreate RPCProc = (iota+1)<<8 | 1 + RPCStart + RPCShutdownGraceful + RPCShutdownForced + RPCExecuteProcess + RPCWaitForProcess + RPCSignalProcess + RPCResizeConsole + RPCGetProperties + RPCModifySettings + RPCNegotiateProtocol + RPCDumpStacks + RPCDeleteContainerState + RPCUpdateContainer + RPCLifecycleNotification +) + +func (rpc RPCProc) String() string { + switch rpc { + case RPCCreate: + return "Create" + case RPCStart: + return "Start" + case RPCShutdownGraceful: + return "ShutdownGraceful" + case RPCShutdownForced: + return "ShutdownForced" + case RPCExecuteProcess: + return "ExecuteProcess" + case RPCWaitForProcess: + return "WaitForProcess" + case RPCSignalProcess: + return "SignalProcess" + case RPCResizeConsole: + return "ResizeConsole" + case RPCGetProperties: + return "GetProperties" + case RPCModifySettings: + return "ModifySettings" + case RPCNegotiateProtocol: + return "NegotiateProtocol" + case RPCDumpStacks: + return "DumpStacks" + case RPCDeleteContainerState: + return "DeleteContainerState" + case RPCUpdateContainer: + return "UpdateContainer" + case RPCLifecycleNotification: + return "LifecycleNotification" + default: + return "0x" + strconv.FormatUint(uint64(rpc), 16) + } +} + +type MsgType uint32 + +const ( + MsgTypeRequest MsgType = 0x10100000 + 
MsgTypeResponse MsgType = 0x20100000 + MsgTypeNotify MsgType = 0x30100000 + MsgTypeMask MsgType = 0xfff00000 + + NotifyContainer = 1<<8 | 1 +) + +func (typ MsgType) String() string { + var s string + switch typ & MsgTypeMask { + case MsgTypeRequest: + s = "Request(" + case MsgTypeResponse: + s = "Response(" + case MsgTypeNotify: + s = "Notify(" + switch typ - MsgTypeNotify { + case NotifyContainer: + s += "Container" + default: + s += fmt.Sprintf("%#x", uint32(typ)) + } + return s + ")" + default: + return fmt.Sprintf("%#x", uint32(typ)) + } + s += RPCProc(typ &^ MsgTypeMask).String() + return s + ")" +} + +// Ocspancontext is the internal JSON representation of the OpenCensus +// `trace.SpanContext` for forwarding to a GCS that supports it. +type Ocspancontext struct { + // TraceID is the `hex` encoded string of the OpenCensus + // `SpanContext.TraceID` to propagate to the guest. + TraceID string `json:",omitempty"` + // SpanID is the `hex` encoded string of the OpenCensus `SpanContext.SpanID` + // to propagate to the guest. + SpanID string `json:",omitempty"` + + // TraceOptions is the OpenCensus `SpanContext.TraceOptions` passed through + // to propagate to the guest. + TraceOptions uint32 `json:",omitempty"` + + // Tracestate is the `base64` encoded string of marshaling the OpenCensus + // `SpanContext.TraceState.Entries()` to JSON. + // + // If `SpanContext.Tracestate == nil || + // len(SpanContext.Tracestate.Entries()) == 0` this will be `""`. + Tracestate string `json:",omitempty"` +} + +type RequestBase struct { + ContainerID string `json:"ContainerId"` + ActivityID guid.GUID `json:"ActivityId"` + + // OpenCensusSpanContext is the encoded OpenCensus `trace.SpanContext` if + // set when making the request. + // + // NOTE: This is not a part of the protocol but because it's a JSON protocol + // adding fields is a non-breaking change. If the guest supports it this is + // just additive context.
+ OpenCensusSpanContext *Ocspancontext `json:"ocsc,omitempty"` +} + +func (req *RequestBase) Base() *RequestBase { + return req +} + +type ResponseBase struct { + Result int32 // HResult + ErrorMessage string `json:",omitempty"` + ActivityID guid.GUID `json:"ActivityId,omitempty"` + ErrorRecords []commonutils.ErrorRecord `json:",omitempty"` +} + +func (resp *ResponseBase) Base() *ResponseBase { + return resp +} + +type NegotiateProtocolRequest struct { + RequestBase + MinimumVersion uint32 + MaximumVersion uint32 +} + +type NegotiateProtocolResponse struct { + ResponseBase + Version uint32 `json:",omitempty"` + Capabilities GcsCapabilities `json:",omitempty"` +} + +type DumpStacksRequest struct { + RequestBase +} + +type DumpStacksResponse struct { + ResponseBase + GuestStacks string +} + +type DeleteContainerStateRequest struct { + RequestBase +} + +type ContainerCreate struct { + RequestBase + ContainerConfig AnyInString +} + +type UvmConfig struct { + SystemType string // must be "Container" + TimeZoneInformation *hcsschema.TimeZoneInformation +} + +type ContainerNotification struct { + RequestBase + Type string // Compute.System.NotificationType + Operation string // Compute.System.ActiveOperation + Result int32 // HResult + ResultInfo AnyInString `json:",omitempty"` +} + +type ContainerExecuteProcess struct { + RequestBase + Settings ExecuteProcessSettings +} + +type ExecuteProcessSettings struct { + ProcessParameters AnyInString + StdioRelaySettings *ExecuteProcessStdioRelaySettings `json:",omitempty"` + VsockStdioRelaySettings *ExecuteProcessVsockStdioRelaySettings `json:",omitempty"` +} + +type ExecuteProcessStdioRelaySettings struct { + StdIn *guid.GUID `json:",omitempty"` + StdOut *guid.GUID `json:",omitempty"` + StdErr *guid.GUID `json:",omitempty"` +} + +type ExecuteProcessVsockStdioRelaySettings struct { + StdIn uint32 `json:",omitempty"` + StdOut uint32 `json:",omitempty"` + StdErr uint32 `json:",omitempty"` +} + +type ContainerResizeConsole struct { 
+ RequestBase + ProcessID uint32 `json:"ProcessId"` + Height uint16 + Width uint16 +} + +type ContainerWaitForProcess struct { + RequestBase + ProcessID uint32 `json:"ProcessId"` + TimeoutInMs uint32 +} + +type ContainerSignalProcess struct { + RequestBase + ProcessID uint32 `json:"ProcessId"` + Options interface{} `json:",omitempty"` +} + +type ContainerPropertiesQuery schema1.PropertyQuery + +func (q *ContainerPropertiesQuery) MarshalText() ([]byte, error) { + return json.Marshal((*schema1.PropertyQuery)(q)) +} + +func (q *ContainerPropertiesQuery) UnmarshalText(b []byte) error { + return json.Unmarshal(b, (*schema1.PropertyQuery)(q)) +} + +type ContainerPropertiesQueryV2 hcsschema.PropertyQuery + +func (q *ContainerPropertiesQueryV2) MarshalText() ([]byte, error) { + return json.Marshal((*hcsschema.PropertyQuery)(q)) +} + +func (q *ContainerPropertiesQueryV2) UnmarshalText(b []byte) error { + return json.Unmarshal(b, (*hcsschema.PropertyQuery)(q)) +} + +type ContainerGetProperties struct { + RequestBase + Query ContainerPropertiesQuery +} + +type ContainerGetPropertiesV2 struct { + RequestBase + Query ContainerPropertiesQueryV2 +} + +type ContainerModifySettings struct { + RequestBase + Request interface{} +} + +type GcsCapabilities struct { + SendHostCreateMessage bool + SendHostStartMessage bool + HvSocketConfigOnStartup bool + SendLifecycleNotifications bool + SupportedSchemaVersions []hcsschema.Version + RuntimeOsType string + GuestDefinedCapabilities json.RawMessage +} + +type ContainerCreateResponse struct { + ResponseBase +} + +type ContainerExecuteProcessResponse struct { + ResponseBase + ProcessID uint32 `json:"ProcessId"` +} + +type ContainerWaitForProcessResponse struct { + ResponseBase + ExitCode uint32 +} + +type ContainerProperties schema1.ContainerProperties + +func (p *ContainerProperties) MarshalText() ([]byte, error) { + return json.Marshal((*schema1.ContainerProperties)(p)) +} + +func (p *ContainerProperties) UnmarshalText(b []byte) error { + 
return json.Unmarshal(b, (*schema1.ContainerProperties)(p)) +} + +type ContainerPropertiesV2 hcsschema.Properties + +func (p *ContainerPropertiesV2) MarshalText() ([]byte, error) { + return json.Marshal((*hcsschema.Properties)(p)) +} + +func (p *ContainerPropertiesV2) UnmarshalText(b []byte) error { + return json.Unmarshal(b, (*hcsschema.Properties)(p)) +} + +type ContainerGetPropertiesResponse struct { + ResponseBase + Properties ContainerProperties +} + +type ContainerGetPropertiesResponseV2 struct { + ResponseBase + Properties ContainerPropertiesV2 +} diff --git a/internal/gcs/protocol.go b/internal/gcs/protocol.go deleted file mode 100644 index 7aeeb4991f..0000000000 --- a/internal/gcs/protocol.go +++ /dev/null @@ -1,371 +0,0 @@ -//go:build windows - -package gcs - -import ( - "encoding/json" - "fmt" - "strconv" - - "github.com/Microsoft/go-winio/pkg/guid" - "github.com/Microsoft/hcsshim/internal/hcs/schema1" - hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" -) - -// LinuxGcsVsockPort is the vsock port number that the Linux GCS will -// connect to. -const LinuxGcsVsockPort = 0x40000000 - -// WindowsGcsHvsockServiceID is the hvsock service ID that the Windows GCS -// will connect to. 
-var WindowsGcsHvsockServiceID = guid.GUID{ - Data1: 0xacef5661, - Data2: 0x84a1, - Data3: 0x4e44, - Data4: [8]uint8{0x85, 0x6b, 0x62, 0x45, 0xe6, 0x9f, 0x46, 0x20}, -} - -// WindowsGcsHvHostID is the hvsock address for the parent of the VM running the GCS -var WindowsGcsHvHostID = guid.GUID{ - Data1: 0x894cc2d6, - Data2: 0x9d79, - Data3: 0x424f, - Data4: [8]uint8{0x93, 0xfe, 0x42, 0x96, 0x9a, 0xe6, 0xd8, 0xd1}, -} - -type anyInString struct { - Value interface{} -} - -func (a *anyInString) MarshalText() ([]byte, error) { - return json.Marshal(a.Value) -} - -func (a *anyInString) UnmarshalText(b []byte) error { - return json.Unmarshal(b, &a.Value) -} - -type rpcProc uint32 - -const ( - rpcCreate rpcProc = (iota+1)<<8 | 1 - rpcStart - rpcShutdownGraceful - rpcShutdownForced - rpcExecuteProcess - rpcWaitForProcess - rpcSignalProcess - rpcResizeConsole - rpcGetProperties - rpcModifySettings - rpcNegotiateProtocol - rpcDumpStacks - rpcDeleteContainerState - rpcUpdateContainer - rpcLifecycleNotification -) - -func (rpc rpcProc) String() string { - switch rpc { - case rpcCreate: - return "Create" - case rpcStart: - return "Start" - case rpcShutdownGraceful: - return "ShutdownGraceful" - case rpcShutdownForced: - return "ShutdownForced" - case rpcExecuteProcess: - return "ExecuteProcess" - case rpcWaitForProcess: - return "WaitForProcess" - case rpcSignalProcess: - return "SignalProcess" - case rpcResizeConsole: - return "ResizeConsole" - case rpcGetProperties: - return "GetProperties" - case rpcModifySettings: - return "ModifySettings" - case rpcNegotiateProtocol: - return "NegotiateProtocol" - case rpcDumpStacks: - return "DumpStacks" - case rpcDeleteContainerState: - return "DeleteContainerState" - case rpcUpdateContainer: - return "UpdateContainer" - case rpcLifecycleNotification: - return "LifecycleNotification" - default: - return "0x" + strconv.FormatUint(uint64(rpc), 16) - } -} - -type msgType uint32 - -const ( - msgTypeRequest msgType = 0x10100000 - 
msgTypeResponse msgType = 0x20100000 - msgTypeNotify msgType = 0x30100000 - msgTypeMask msgType = 0xfff00000 - - notifyContainer = 1<<8 | 1 -) - -func (typ msgType) String() string { - var s string - switch typ & msgTypeMask { - case msgTypeRequest: - s = "Request(" - case msgTypeResponse: - s = "Response(" - case msgTypeNotify: - s = "Notify(" - switch typ - msgTypeNotify { - case notifyContainer: - s += "Container" - default: - s += fmt.Sprintf("%#x", uint32(typ)) - } - return s + ")" - default: - return fmt.Sprintf("%#x", uint32(typ)) - } - s += rpcProc(typ &^ msgTypeMask).String() - return s + ")" -} - -// ocspancontext is the internal JSON representation of the OpenCensus -// `trace.SpanContext` for fowarding to a GCS that supports it. -type ocspancontext struct { - // TraceID is the `hex` encoded string of the OpenCensus - // `SpanContext.TraceID` to propagate to the guest. - TraceID string `json:",omitempty"` - // SpanID is the `hex` encoded string of the OpenCensus `SpanContext.SpanID` - // to propagate to the guest. - SpanID string `json:",omitempty"` - - // TraceOptions is the OpenCensus `SpanContext.TraceOptions` passed through - // to propagate to the guest. - TraceOptions uint32 `json:",omitempty"` - - // Tracestate is the `base64` encoded string of marshaling the OpenCensus - // `SpanContext.TraceState.Entries()` to JSON. - // - // If `SpanContext.Tracestate == nil || - // len(SpanContext.Tracestate.Entries()) == 0` this will be `""`. - Tracestate string `json:",omitempty"` -} - -type requestBase struct { - ContainerID string `json:"ContainerId"` - ActivityID guid.GUID `json:"ActivityId"` - - // OpenCensusSpanContext is the encoded OpenCensus `trace.SpanContext` if - // set when making the request. - // - // NOTE: This is not a part of the protocol but because its a JSON protocol - // adding fields is a non-breaking change. If the guest supports it this is - // just additive context. 
- OpenCensusSpanContext *ocspancontext `json:"ocsc,omitempty"` -} - -func (req *requestBase) Base() *requestBase { - return req -} - -type responseBase struct { - Result int32 // HResult - ErrorMessage string `json:",omitempty"` - ActivityID guid.GUID `json:"ActivityId,omitempty"` - ErrorRecords []errorRecord `json:",omitempty"` -} - -type errorRecord struct { - Result int32 // HResult - Message string - StackTrace string `json:",omitempty"` - ModuleName string - FileName string - Line uint32 - FunctionName string `json:",omitempty"` -} - -func (resp *responseBase) Base() *responseBase { - return resp -} - -type negotiateProtocolRequest struct { - requestBase - MinimumVersion uint32 - MaximumVersion uint32 -} - -type negotiateProtocolResponse struct { - responseBase - Version uint32 `json:",omitempty"` - Capabilities gcsCapabilities `json:",omitempty"` -} - -type dumpStacksRequest struct { - requestBase -} - -type dumpStacksResponse struct { - responseBase - GuestStacks string -} - -type deleteContainerStateRequest struct { - requestBase -} - -type containerCreate struct { - requestBase - ContainerConfig anyInString -} - -type uvmConfig struct { - SystemType string // must be "Container" - TimeZoneInformation *hcsschema.TimeZoneInformation -} - -type containerNotification struct { - requestBase - Type string // Compute.System.NotificationType - Operation string // Compute.System.ActiveOperation - Result int32 // HResult - ResultInfo anyInString `json:",omitempty"` -} - -type containerExecuteProcess struct { - requestBase - Settings executeProcessSettings -} - -type executeProcessSettings struct { - ProcessParameters anyInString - StdioRelaySettings *executeProcessStdioRelaySettings `json:",omitempty"` - VsockStdioRelaySettings *executeProcessVsockStdioRelaySettings `json:",omitempty"` -} - -type executeProcessStdioRelaySettings struct { - StdIn *guid.GUID `json:",omitempty"` - StdOut *guid.GUID `json:",omitempty"` - StdErr *guid.GUID `json:",omitempty"` -} - -type 
executeProcessVsockStdioRelaySettings struct { - StdIn uint32 `json:",omitempty"` - StdOut uint32 `json:",omitempty"` - StdErr uint32 `json:",omitempty"` -} - -type containerResizeConsole struct { - requestBase - ProcessID uint32 `json:"ProcessId"` - Height uint16 - Width uint16 -} - -type containerWaitForProcess struct { - requestBase - ProcessID uint32 `json:"ProcessId"` - TimeoutInMs uint32 -} - -type containerSignalProcess struct { - requestBase - ProcessID uint32 `json:"ProcessId"` - Options interface{} `json:",omitempty"` -} - -type containerPropertiesQuery schema1.PropertyQuery - -func (q *containerPropertiesQuery) MarshalText() ([]byte, error) { - return json.Marshal((*schema1.PropertyQuery)(q)) -} - -func (q *containerPropertiesQuery) UnmarshalText(b []byte) error { - return json.Unmarshal(b, (*schema1.PropertyQuery)(q)) -} - -type containerPropertiesQueryV2 hcsschema.PropertyQuery - -func (q *containerPropertiesQueryV2) MarshalText() ([]byte, error) { - return json.Marshal((*hcsschema.PropertyQuery)(q)) -} - -func (q *containerPropertiesQueryV2) UnmarshalText(b []byte) error { - return json.Unmarshal(b, (*hcsschema.PropertyQuery)(q)) -} - -type containerGetProperties struct { - requestBase - Query containerPropertiesQuery -} - -type containerGetPropertiesV2 struct { - requestBase - Query containerPropertiesQueryV2 -} - -type containerModifySettings struct { - requestBase - Request interface{} -} - -type gcsCapabilities struct { - SendHostCreateMessage bool - SendHostStartMessage bool - HvSocketConfigOnStartup bool - SendLifecycleNotifications bool - SupportedSchemaVersions []hcsschema.Version - RuntimeOsType string - GuestDefinedCapabilities json.RawMessage -} - -type containerCreateResponse struct { - responseBase -} - -type containerExecuteProcessResponse struct { - responseBase - ProcessID uint32 `json:"ProcessId"` -} - -type containerWaitForProcessResponse struct { - responseBase - ExitCode uint32 -} - -type containerProperties 
schema1.ContainerProperties - -func (p *containerProperties) MarshalText() ([]byte, error) { - return json.Marshal((*schema1.ContainerProperties)(p)) -} - -func (p *containerProperties) UnmarshalText(b []byte) error { - return json.Unmarshal(b, (*schema1.ContainerProperties)(p)) -} - -type containerPropertiesV2 hcsschema.Properties - -func (p *containerPropertiesV2) MarshalText() ([]byte, error) { - return json.Marshal((*hcsschema.Properties)(p)) -} - -func (p *containerPropertiesV2) UnmarshalText(b []byte) error { - return json.Unmarshal(b, (*hcsschema.Properties)(p)) -} - -type containerGetPropertiesResponse struct { - responseBase - Properties containerProperties -} - -type containerGetPropertiesResponseV2 struct { - responseBase - Properties containerPropertiesV2 -} diff --git a/internal/guest/bridge/bridge.go b/internal/guest/bridge/bridge.go index f14663344f..024c7108b9 100644 --- a/internal/guest/bridge/bridge.go +++ b/internal/guest/bridge/bridge.go @@ -11,9 +11,7 @@ import ( "encoding/json" "fmt" "io" - "math" "os" - "strconv" "sync" "sync/atomic" "time" @@ -23,7 +21,8 @@ import ( "go.opencensus.io/trace" "go.opencensus.io/trace/tracestate" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" + "github.com/Microsoft/hcsshim/internal/bridgeutils/commonutils" + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/guest/prot" "github.com/Microsoft/hcsshim/internal/guest/runtime/hcsv2" "github.com/Microsoft/hcsshim/internal/log" @@ -360,7 +359,7 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser if span != nil { oc.SetSpanStatus(span, err) } - setErrorForResponseBase(resp.Base(), err) + setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */) } br.response = resp b.responseChan <- br @@ -446,45 +445,9 @@ func (b *Bridge) PublishNotification(n *prot.ContainerNotification) { // setErrorForResponseBase modifies the passed-in MessageResponseBase to // contain information 
pertaining to the given error. -func setErrorForResponseBase(response *prot.MessageResponseBase, errForResponse error) { - errorMessage := errForResponse.Error() - stackString := "" - fileName := "" - // We use -1 as a sentinel if no line number found (or it cannot be parsed), - // but that will ultimately end up as [math.MaxUint32], so set it to that explicitly. - // (Still keep using -1 for backwards compatibility ...) - lineNumber := uint32(math.MaxUint32) - functionName := "" - if stack := gcserr.BaseStackTrace(errForResponse); stack != nil { - bottomFrame := stack[0] - stackString = fmt.Sprintf("%+v", stack) - fileName = fmt.Sprintf("%s", bottomFrame) - lineNumberStr := fmt.Sprintf("%d", bottomFrame) - if n, err := strconv.ParseUint(lineNumberStr, 10, 32); err == nil { - lineNumber = uint32(n) - } else { - logrus.WithFields(logrus.Fields{ - "line-number": lineNumberStr, - logrus.ErrorKey: err, - }).Error("opengcs::bridge::setErrorForResponseBase - failed to parse line number, using -1 instead") - } - functionName = fmt.Sprintf("%n", bottomFrame) - } - hresult, err := gcserr.GetHresult(errForResponse) - if err != nil { - // Default to using the generic failure HRESULT. 
- hresult = gcserr.HrFail - } +func setErrorForResponseBase(response *prot.MessageResponseBase, errForResponse error, moduleName string) { + hresult, errorMessage, newRecord := commonutils.SetErrorForResponseBaseUtil(errForResponse, moduleName) response.Result = int32(hresult) response.ErrorMessage = errorMessage - newRecord := prot.ErrorRecord{ - Result: int32(hresult), - Message: errorMessage, - StackTrace: stackString, - ModuleName: "gcs", - FileName: fileName, - Line: lineNumber, - FunctionName: functionName, - } response.ErrorRecords = append(response.ErrorRecords, newRecord) } diff --git a/internal/guest/bridge/bridge_unit_test.go b/internal/guest/bridge/bridge_unit_test.go index 67f583da05..613eaf326b 100644 --- a/internal/guest/bridge/bridge_unit_test.go +++ b/internal/guest/bridge/bridge_unit_test.go @@ -12,7 +12,7 @@ import ( "sync" "testing" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/guest/prot" "github.com/Microsoft/hcsshim/internal/guest/transport" "github.com/pkg/errors" diff --git a/internal/guest/bridge/bridge_v2.go b/internal/guest/bridge/bridge_v2.go index f9712abc9d..800094e549 100644 --- a/internal/guest/bridge/bridge_v2.go +++ b/internal/guest/bridge/bridge_v2.go @@ -13,8 +13,8 @@ import ( "go.opencensus.io/trace" "golang.org/x/sys/unix" - "github.com/Microsoft/hcsshim/internal/guest/commonutils" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" + "github.com/Microsoft/hcsshim/internal/bridgeutils/commonutils" + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/guest/prot" "github.com/Microsoft/hcsshim/internal/guest/runtime/hcsv2" "github.com/Microsoft/hcsshim/internal/guest/stdio" diff --git a/internal/guest/commonutils/utilities.go b/internal/guest/commonutils/utilities.go deleted file mode 100644 index adcf70e6c2..0000000000 --- a/internal/guest/commonutils/utilities.go +++ 
/dev/null @@ -1,26 +0,0 @@ -package commonutils - -import ( - "encoding/json" - "io" - - "github.com/Microsoft/hcsshim/internal/guest/gcserr" -) - -// UnmarshalJSONWithHresult unmarshals the given data into the given interface, and -// wraps any error returned in an HRESULT error. -func UnmarshalJSONWithHresult(data []byte, v interface{}) error { - if err := json.Unmarshal(data, v); err != nil { - return gcserr.WrapHresult(err, gcserr.HrVmcomputeInvalidJSON) - } - return nil -} - -// DecodeJSONWithHresult decodes the JSON from the given reader into the given -// interface, and wraps any error returned in an HRESULT error. -func DecodeJSONWithHresult(r io.Reader, v interface{}) error { - if err := json.NewDecoder(r).Decode(v); err != nil { - return gcserr.WrapHresult(err, gcserr.HrVmcomputeInvalidJSON) - } - return nil -} diff --git a/internal/guest/prot/protocol.go b/internal/guest/prot/protocol.go index 891891d510..576ac5e5f1 100644 --- a/internal/guest/prot/protocol.go +++ b/internal/guest/prot/protocol.go @@ -11,7 +11,7 @@ import ( oci "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" - "github.com/Microsoft/hcsshim/internal/guest/commonutils" + "github.com/Microsoft/hcsshim/internal/bridgeutils/commonutils" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" @@ -501,9 +501,9 @@ type ResourceModificationRequestResponse struct { Settings interface{} `json:",omitempty"` } -// ContainerModifySettings is the message from the HCS specifying how a certain +// containerModifySettings is the message from the HCS specifying how a certain // container resource should be modified. -type ContainerModifySettings struct { +type containerModifySettings struct { MessageBase Request interface{} } @@ -512,9 +512,9 @@ type ContainerModifySettings struct { // ContainerModifySettings message. 
This function is required because properties // such as `Settings` can be of many types identified by the `ResourceType` and // require dynamic unmarshalling. -func UnmarshalContainerModifySettings(b []byte) (*ContainerModifySettings, error) { +func UnmarshalContainerModifySettings(b []byte) (*containerModifySettings, error) { // Unmarshal the message. - var request ContainerModifySettings + var request containerModifySettings var requestRawSettings json.RawMessage request.Request = &requestRawSettings if err := commonutils.UnmarshalJSONWithHresult(b, &request); err != nil { @@ -601,26 +601,13 @@ func UnmarshalContainerModifySettings(b []byte) (*ContainerModifySettings, error return &request, nil } -// ErrorRecord represents a single error to be reported back to the HCS. It -// allows for specifying information about the source of the error, as well as -// an error message and stack trace. -type ErrorRecord struct { - Result int32 - Message string - StackTrace string `json:",omitempty"` - ModuleName string - FileName string - Line uint32 - FunctionName string `json:",omitempty"` -} - // MessageResponseBase is the base type embedded in all messages sent from the // GCS to the HCS except for ContainerNotification. type MessageResponseBase struct { Result int32 - ActivityID string `json:"ActivityId,omitempty"` - ErrorMessage string `json:",omitempty"` // Only used by hcsshim external bridge - ErrorRecords []ErrorRecord `json:",omitempty"` + ActivityID string `json:"ActivityId,omitempty"` + ErrorMessage string `json:",omitempty"` // Only used by hcsshim external bridge + ErrorRecords []commonutils.ErrorRecord `json:",omitempty"` } // Base returns the response base by reference. 
diff --git a/internal/guest/runtime/hcsv2/container.go b/internal/guest/runtime/hcsv2/container.go index d1094f7e9a..79dd2a732e 100644 --- a/internal/guest/runtime/hcsv2/container.go +++ b/internal/guest/runtime/hcsv2/container.go @@ -18,7 +18,7 @@ import ( "github.com/sirupsen/logrus" "go.opencensus.io/trace" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/guest/prot" "github.com/Microsoft/hcsshim/internal/guest/runtime" specGuest "github.com/Microsoft/hcsshim/internal/guest/spec" diff --git a/internal/guest/runtime/hcsv2/network.go b/internal/guest/runtime/hcsv2/network.go index 136f544153..f6052e84f9 100644 --- a/internal/guest/runtime/hcsv2/network.go +++ b/internal/guest/runtime/hcsv2/network.go @@ -14,7 +14,7 @@ import ( "github.com/vishvananda/netns" "go.opencensus.io/trace" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/guest/network" "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" diff --git a/internal/guest/runtime/hcsv2/process.go b/internal/guest/runtime/hcsv2/process.go index e29e6e62f7..e94c9792f6 100644 --- a/internal/guest/runtime/hcsv2/process.go +++ b/internal/guest/runtime/hcsv2/process.go @@ -10,7 +10,7 @@ import ( "sync" "syscall" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/guest/runtime" "github.com/Microsoft/hcsshim/internal/guest/stdio" "github.com/Microsoft/hcsshim/internal/log" diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 11db7b4534..3c8e4eee39 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -22,16 +22,8 @@ import ( "github.com/Microsoft/cosesign1go/pkg/cosesign1" 
didx509resolver "github.com/Microsoft/didx509go/pkg/did-x509-resolver" - "github.com/Microsoft/hcsshim/pkg/annotations" - "github.com/Microsoft/hcsshim/pkg/securitypolicy" - "github.com/mattn/go-shellwords" - "github.com/opencontainers/runtime-spec/specs-go" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "golang.org/x/sys/unix" - + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/debug" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" "github.com/Microsoft/hcsshim/internal/guest/policy" "github.com/Microsoft/hcsshim/internal/guest/prot" "github.com/Microsoft/hcsshim/internal/guest/runtime" @@ -49,6 +41,13 @@ import ( "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "github.com/Microsoft/hcsshim/internal/verity" + "github.com/Microsoft/hcsshim/pkg/annotations" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" + "github.com/mattn/go-shellwords" + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" ) // UVMContainerID is the ContainerID that will be sent on any prot.MessageBase @@ -121,6 +120,7 @@ func (h *Host) SetConfidentialUVMOptions(ctx context.Context, r *guestresource.L policy.DefaultCRIMounts(), policy.DefaultCRIPrivilegedMounts(), maxErrorMessageLength, + "linux", ) if err != nil { return err diff --git a/internal/guest/runtime/runc/runc.go b/internal/guest/runtime/runc/runc.go index 555fd17a7e..cd11cefdda 100644 --- a/internal/guest/runtime/runc/runc.go +++ b/internal/guest/runtime/runc/runc.go @@ -16,7 +16,7 @@ import ( "github.com/pkg/errors" "golang.org/x/sys/unix" - "github.com/Microsoft/hcsshim/internal/guest/commonutils" + "github.com/Microsoft/hcsshim/internal/bridgeutils/commonutils" "github.com/Microsoft/hcsshim/internal/guest/runtime" "github.com/Microsoft/hcsshim/internal/guest/stdio" ) diff --git 
a/internal/guest/runtime/runtime.go b/internal/guest/runtime/runtime.go index a8c5231cfc..db24459c27 100644 --- a/internal/guest/runtime/runtime.go +++ b/internal/guest/runtime/runtime.go @@ -8,7 +8,7 @@ import ( "io" "syscall" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/guest/stdio" oci "github.com/opencontainers/runtime-spec/specs-go" ) diff --git a/internal/hcs/schema2/cimfs.go b/internal/hcs/schema2/cimfs.go index 52fb62a829..b2a39d133b 100644 --- a/internal/hcs/schema2/cimfs.go +++ b/internal/hcs/schema2/cimfs.go @@ -9,9 +9,18 @@ package hcsschema +import "github.com/Microsoft/go-winio/pkg/guid" + type CimMount struct { ImagePath string `json:"ImagePath,omitempty"` FileSystemName string `json:"FileSystemName,omitempty"` VolumeGuid string `json:"VolumeGuid,omitempty"` MountFlags uint32 `json:"MountFlags,omitempty"` } + +type BlockCIMMount struct { + BlockLUNs []uint32 `json:"BlockLUNs,omitempty"` + CimNames []string `json:"CimNames,omitempty"` + VolumeGuid guid.GUID `json:"VolumeGuid,omitempty"` + MountFlags uint32 `json:"MountFlags,omitempty"` +} diff --git a/internal/hcsoci/create.go b/internal/hcsoci/create.go index ab0fa6c272..cd9ea39a37 100644 --- a/internal/hcsoci/create.go +++ b/internal/hcsoci/create.go @@ -23,6 +23,7 @@ import ( "github.com/Microsoft/hcsshim/internal/layers" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oci" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "github.com/Microsoft/hcsshim/internal/resources" "github.com/Microsoft/hcsshim/internal/schemaversion" "github.com/Microsoft/hcsshim/internal/uvm" @@ -264,10 +265,25 @@ func CreateContainer(ctx context.Context, createOptions *CreateOptions) (_ cow.C // v1 Argon or Xenon. Pass the document directly to HCS. hcsDocument = v1 } else if coi.HostingSystem != nil { - // v2 Xenon. Pass the container object to the UVM. 
- gcsDocument = &hcsschema.HostedSystem{ - SchemaVersion: schemaversion.SchemaV21(), - Container: v2, + isCWCOWUVM := false + if createOptions.HostingSystem.WCOWconfidentialUVMOptions != nil { + isCWCOWUVM = true + } + if isCWCOWUVM { + // confidential wcow uvm + gcsDocument = &guestresource.CWCOWHostedSystem{ + Spec: *createOptions.Spec, + CWCOWHostedSystem: hcsschema.HostedSystem{ + SchemaVersion: schemaversion.SchemaV21(), + Container: v2, + }, + } + } else { + // v2 Xenon. Pass the container object to the UVM. + gcsDocument = &hcsschema.HostedSystem{ + SchemaVersion: schemaversion.SchemaV21(), + Container: v2, + } } } else { // v2 Argon. Pass the container object to the HCS. diff --git a/internal/layers/wcow_mount.go b/internal/layers/wcow_mount.go index 9df9f199eb..0dd4b8af5e 100644 --- a/internal/layers/wcow_mount.go +++ b/internal/layers/wcow_mount.go @@ -43,7 +43,7 @@ func MountWCOWLayers(ctx context.Context, containerID string, vm *uvm.UtilityVM, if vm == nil { return mountProcessIsolatedBlockCIMLayers(ctx, containerID, l) } - return nil, nil, fmt.Errorf("hyperv isolated containers aren't supported with block cim layers") + return mountHypervIsolatedBlockCIMLayers(ctx, l, vm, containerID) default: return nil, nil, fmt.Errorf("invalid layer type %T", wl) } @@ -329,6 +329,89 @@ func mountProcessIsolatedBlockCIMLayers(ctx context.Context, containerID string, return mountedLayers, rcl, nil } +func mountHypervIsolatedBlockCIMLayers(ctx context.Context, l *wcowBlockCIMLayers, vm *uvm.UtilityVM, containerID string) (_ *MountedWCOWLayers, _ resources.ResourceCloser, err error) { + ctx, span := oc.StartSpan(ctx, "mountHyperVIsolatedBlockCIMLayers") + defer func() { + oc.SetSpanStatus(span, err) + span.End() + }() + + rcl := &resources.ResourceCloserList{} + defer func() { + if err != nil { + if rErr := rcl.Release(ctx); rErr != nil { + log.G(ctx).WithError(err).Warnf("mount process isolated forked CIM layers, undo failed with: %s", rErr) + } + } + }() + + 
log.G(ctx).WithFields(logrus.Fields{ + "scratch": l.scratchLayerPath, + "merged layer": l.mergedLayer, + "parent layers": l.parentLayers, + }).Debug("mounting hyperv isolated block CIM layers") + + mountedCIMs, err := vm.MountBlockCIMs(ctx, l.mergedLayer, l.parentLayers, containerID) + if err != nil { + return nil, nil, fmt.Errorf("failed to mount block CIMs in UVM: %w", err) + } + rcl.Add(mountedCIMs) + + // mount the CIM inside UVM now + log.G(ctx).WithField("volume", mountedCIMs.VolumePath).Debug("mounted blockCIM layers for hyperV isolated container") + + hostPath := filepath.Join(l.scratchLayerPath, "sandbox.vhdx") + + scsiMount, err := vm.SCSIManager.AddVirtualDisk(ctx, hostPath, false, vm.ID(), "", &scsi.MountConfig{ + FormatWithRefs: true, + }) + if err != nil { + return nil, nil, fmt.Errorf("failed to add SCSI scratch VHD: %w", err) + } + containerScratchPathInUVM := scsiMount.GuestPath() + rcl.Add(scsiMount) + + log.G(ctx).WithFields(logrus.Fields{ + "hostPath": hostPath, + "uvmPath": containerScratchPathInUVM, + }).Debug("mounted scratch VHD") + + mountedCIMLayerID, err := cimlayer.LayerID(mountedCIMs.VolumePath) + if err != nil { + return nil, nil, fmt.Errorf("failed to get layer ID for mounted block CIM: %w", err) + } + + ml := &MountedWCOWLayers{ + RootFS: containerScratchPathInUVM, + MountedLayerPaths: []MountedWCOWLayer{ + { + LayerID: mountedCIMLayerID, + MountedPath: mountedCIMs.VolumePath, + }, + }, + } + + hcsLayers := []hcsschema.Layer{ + { + Id: mountedCIMLayerID, + Path: filepath.Join(mountedCIMs.VolumePath, "Files"), + }, + } + + err = vm.CombineLayersForCWCOW(ctx, hcsLayers, ml.RootFS, containerID, hcsschema.UnionFS) + if err != nil { + return nil, nil, err + } + log.G(ctx).Debug("hcsshim::mountHyperVIsolatedBlockCIMLayers Succeeded") + + return ml, &wcowIsolatedWCIFSLayerCloser{ + uvm: vm, + guestCombinedLayersPath: ml.RootFS, + scratchMount: scsiMount, + layerClosers: []resources.ResourceCloser{rcl}, + }, nil +} + type 
wcowIsolatedWCIFSLayerCloser struct { uvm *uvm.UtilityVM guestCombinedLayersPath string @@ -440,7 +523,7 @@ func mountHypervIsolatedWCIFSLayers(ctx context.Context, l *wcowWCIFSLayers, vm }) } - err = vm.CombineLayersWCOW(ctx, hcsLayers, ml.RootFS) + err = vm.CombineLayersWCOW(ctx, hcsLayers, ml.RootFS, hcsschema.WCIFS) if err != nil { return nil, nil, err } diff --git a/internal/layers/wcow_parse.go b/internal/layers/wcow_parse.go index 15a820ccaa..c2da6f5b8d 100644 --- a/internal/layers/wcow_parse.go +++ b/internal/layers/wcow_parse.go @@ -216,33 +216,7 @@ func ParseWCOWLayers(rootfs []*types.Mount, layerFolders []string) (WCOWLayers, } } -// GetWCOWUVMBootFilesFromLayers prepares the UVM boot files from the rootfs or layerFolders. -func GetWCOWUVMBootFilesFromLayers(ctx context.Context, rootfs []*types.Mount, layerFolders []string) (*uvm.WCOWBootFiles, error) { - var parentLayers []string - var scratchLayer string - var err error - - if err = validateRootfsAndLayers(rootfs, layerFolders); err != nil { - return nil, err - } - - if len(layerFolders) > 0 { - parentLayers = layerFolders[:len(layerFolders)-1] - scratchLayer = layerFolders[len(layerFolders)-1] - } else { - m := rootfs[0] - switch m.Type { - case legacyMountType: - parentLayers, err = getOptionAsArray(m, parentLayerPathsFlag) - if err != nil { - return nil, err - } - scratchLayer = m.Source - default: - return nil, fmt.Errorf("mount type '%s' is not supported for UVM boot", m.Type) - } - } - +func makeLegacyWCOWUVMBootFiles(ctx context.Context, scratchLayer string, parentLayers []string) (*uvm.WCOWBootFiles, error) { uvmFolder, err := uvmfolder.LocateUVMFolder(ctx, parentLayers) if err != nil { return nil, fmt.Errorf("failed to locate utility VM folder from layer folders: %w", err) @@ -272,3 +246,32 @@ func GetWCOWUVMBootFilesFromLayers(ctx context.Context, rootfs []*types.Mount, l }, }, nil } + +// GetWCOWUVMBootFilesFromLayers prepares the UVM boot files from the rootfs or layerFolders. 
+func GetWCOWUVMBootFilesFromLayers(ctx context.Context, rootfs []*types.Mount, layerFolders []string) (*uvm.WCOWBootFiles, error) { + var parentLayers []string + var scratchLayer string + var err error + + if err = validateRootfsAndLayers(rootfs, layerFolders); err != nil { + return nil, err + } + + if len(layerFolders) > 0 { + parentLayers = layerFolders[:len(layerFolders)-1] + scratchLayer = layerFolders[len(layerFolders)-1] + return makeLegacyWCOWUVMBootFiles(ctx, scratchLayer, parentLayers) + } else if rootfs[0].Type == legacyMountType { + parentLayers, err := getOptionAsArray(rootfs[0], parentLayerPathsFlag) + if err != nil { + return nil, err + } + return makeLegacyWCOWUVMBootFiles(ctx, rootfs[0].Source, parentLayers) + } else if rootfs[0].Type == blockCIMMountType { + return &uvm.WCOWBootFiles{ + BootType: uvm.BlockCIMBoot, + BlockCIMFiles: &uvm.BlockCIMBootFiles{}, + }, nil + } + return nil, fmt.Errorf("mount type '%s' is not supported for UVM boot", rootfs[0].Type) +} diff --git a/internal/oci/uvm.go b/internal/oci/uvm.go index cf8de1227d..bc5dff780f 100644 --- a/internal/oci/uvm.go +++ b/internal/oci/uvm.go @@ -5,6 +5,7 @@ package oci import ( "context" "errors" + "fmt" "maps" "strconv" @@ -189,10 +190,30 @@ func handleAnnotationFullyPhysicallyBacked(ctx context.Context, a map[string]str } } -// handleSecurityPolicy handles parsing SecurityPolicy and NoSecurityHardware and setting +// handleWCOWSecurityPolicy handles parsing SecurityPolicy for confidential hyper-v isolated windows containers +func handleWCOWSecurityPolicy(ctx context.Context, a map[string]string, wopts *uvm.OptionsWCOW) error { + wopts.SecurityPolicy = ParseAnnotationsString(a, annotations.WCOWSecurityPolicy, wopts.SecurityPolicy) + wopts.SecurityPolicyEnforcer = ParseAnnotationsString(a, annotations.WCOWSecurityPolicyEnforcer, wopts.SecurityPolicyEnforcer) + // allow actual isolated boot etc to be ignored if we have no hardware. 
Required for dev + // this is not a security issue as the attestation will fail without a genuine report + noSecurityHardware := ParseAnnotationsBool(ctx, a, annotations.NoSecurityHardware, false) + + // TODO: Process annotations.NoSecurityHardware here for cwcow cases! + if len(wopts.SecurityPolicy) > 0 { + wopts.SecurityPolicyEnabled = true + + if noSecurityHardware { + wopts.NoSecurityHardware = true + } + return uvm.SetDefaultConfidentialWCOWBootConfig(wopts) + } + return nil +} + +// handleLCOWSecurityPolicy handles parsing SecurityPolicy and NoSecurityHardware and setting // implied options from the results. Both LCOW only, not WCOW. -func handleSecurityPolicy(ctx context.Context, a map[string]string, lopts *uvm.OptionsLCOW) { - lopts.SecurityPolicy = ParseAnnotationsString(a, annotations.SecurityPolicy, lopts.SecurityPolicy) +func handleLCOWSecurityPolicy(ctx context.Context, a map[string]string, lopts *uvm.OptionsLCOW) { + lopts.SecurityPolicy = ParseAnnotationsString(a, annotations.LCOWSecurityPolicy, lopts.SecurityPolicy) // allow actual isolated boot etc to be ignored if we have no hardware. 
Required for dev // this is not a security issue as the attestation will fail without a genuine report noSecurityHardware := ParseAnnotationsBool(ctx, a, annotations.NoSecurityHardware, false) @@ -308,8 +329,8 @@ func SpecToUVMCreateOpts(ctx context.Context, s *specs.Spec, id, owner string) ( lopts.ExtraVSockPorts = ParseAnnotationCommaSeparatedUint32(ctx, s.Annotations, iannotations.ExtraVSockPorts, lopts.ExtraVSockPorts) handleAnnotationBootFilesPath(ctx, s.Annotations, lopts) lopts.EnableScratchEncryption = ParseAnnotationsBool(ctx, s.Annotations, annotations.EncryptedScratchDisk, lopts.EnableScratchEncryption) - lopts.SecurityPolicy = ParseAnnotationsString(s.Annotations, annotations.SecurityPolicy, lopts.SecurityPolicy) - lopts.SecurityPolicyEnforcer = ParseAnnotationsString(s.Annotations, annotations.SecurityPolicyEnforcer, lopts.SecurityPolicyEnforcer) + lopts.SecurityPolicy = ParseAnnotationsString(s.Annotations, annotations.LCOWSecurityPolicy, lopts.SecurityPolicy) + lopts.SecurityPolicyEnforcer = ParseAnnotationsString(s.Annotations, annotations.LCOWSecurityPolicyEnforcer, lopts.SecurityPolicyEnforcer) lopts.UVMReferenceInfoFile = ParseAnnotationsString(s.Annotations, annotations.UVMReferenceInfoFile, lopts.UVMReferenceInfoFile) lopts.KernelBootOptions = ParseAnnotationsString(s.Annotations, annotations.KernelBootOptions, lopts.KernelBootOptions) lopts.DisableTimeSyncService = ParseAnnotationsBool(ctx, s.Annotations, annotations.DisableLCOWTimeSyncService, lopts.DisableTimeSyncService) @@ -320,7 +341,7 @@ func SpecToUVMCreateOpts(ctx context.Context, s *specs.Spec, id, owner string) ( // SecurityPolicy is very sensitive to other settings and will silently change those that are incompatible. // Eg VMPem device count, overridden kernel option cannot be respected. 
- handleSecurityPolicy(ctx, s.Annotations, lopts) + handleLCOWSecurityPolicy(ctx, s.Annotations, lopts) // override the default GuestState and DmVerityRootFs filenames if specified lopts.GuestStateFile = ParseAnnotationsString(s.Annotations, annotations.GuestStateFile, lopts.GuestStateFile) @@ -343,6 +364,9 @@ func SpecToUVMCreateOpts(ctx context.Context, s *specs.Spec, id, owner string) ( wopts.NoInheritHostTimezone = ParseAnnotationsBool(ctx, s.Annotations, annotations.NoInheritHostTimezone, wopts.NoInheritHostTimezone) wopts.AdditionalRegistryKeys = append(wopts.AdditionalRegistryKeys, parseAdditionalRegistryValues(ctx, s.Annotations)...) handleAnnotationFullyPhysicallyBacked(ctx, s.Annotations, wopts) + if err := handleWCOWSecurityPolicy(ctx, s.Annotations, wopts); err != nil { + return nil, fmt.Errorf("failed to process WCOW security policy: %w", err) + } return wopts, nil } return nil, errors.New("cannot create UVM opts spec is not LCOW or WCOW") diff --git a/internal/protocol/guestresource/resources.go b/internal/protocol/guestresource/resources.go index 89c2003d7e..caf6215ac2 100644 --- a/internal/protocol/guestresource/resources.go +++ b/internal/protocol/guestresource/resources.go @@ -4,6 +4,7 @@ import ( "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/opencontainers/runtime-spec/specs-go" + "github.com/Microsoft/go-winio/pkg/guid" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" ) @@ -26,6 +27,10 @@ const ( // ResourceTypeMappedVirtualDisk is the modify resource type for mapped // virtual disks ResourceTypeMappedVirtualDisk guestrequest.ResourceType = "MappedVirtualDisk" + // ResourceTypeMappedVirtualDiskForContainerScratch is the modify resource type + // specifically for refs formatting and mounting scratch vhds for c-wcow cases only. 
+ ResourceTypeMappedVirtualDiskForContainerScratch guestrequest.ResourceType = "MappedVirtualDiskForContainerScratch" + ResourceTypeWCOWBlockCims guestrequest.ResourceType = "WCOWBlockCims" // ResourceTypeNetwork is the modify resource type for the `NetworkAdapterV2` // device. ResourceTypeNetwork guestrequest.ResourceType = "Network" @@ -33,6 +38,10 @@ const ( // ResourceTypeCombinedLayers is the modify resource type for combined // layers ResourceTypeCombinedLayers guestrequest.ResourceType = "CombinedLayers" + // ResourceTypeCWCOWCombinedLayers is the modify resource type for combined + // layers call for cwcow cases. This resource type wraps containerID around + // ResourceTypeCombinedLayers. + ResourceTypeCWCOWCombinedLayers guestrequest.ResourceType = "CWCOWCombinedLayers" // ResourceTypeVPMemDevice is the modify resource type for VPMem devices ResourceTypeVPMemDevice guestrequest.ResourceType = "VPMemDevice" // ResourceTypeVPCIDevice is the modify resource type for vpci devices @@ -62,9 +71,20 @@ type LCOWCombinedLayers struct { } type WCOWCombinedLayers struct { - ContainerRootPath string `json:"ContainerRootPath,omitempty"` - Layers []hcsschema.Layer `json:"Layers,omitempty"` - ScratchPath string `json:"ScratchPath,omitempty"` + ContainerRootPath string `json:"ContainerRootPath,omitempty"` + Layers []hcsschema.Layer `json:"Layers,omitempty"` + ScratchPath string `json:"ScratchPath,omitempty"` + FilterType hcsschema.FileSystemFilterType `json:"FilterType,omitempty"` +} + +type CWCOWCombinedLayers struct { + ContainerID string `json:"ContainerID,omitempty"` + CombinedLayers WCOWCombinedLayers `json:"CombinedLayers,omitempty"` +} + +type CWCOWHostedSystem struct { + Spec specs.Spec + CWCOWHostedSystem hcsschema.HostedSystem } // Defines the schema for hosted settings passed to GCS and/or OpenGCS @@ -92,6 +112,20 @@ type LCOWMappedVirtualDisk struct { Filesystem string `json:"Filesystem,omitempty"` } +type BlockCIMDevice struct { + CimName string + Lun int32 + 
 Digest string +} + +type WCOWBlockCIMMounts struct { + // BlockCIMs should be ordered from merged CIM followed by Layer n .. layer 1 + BlockCIMs []BlockCIMDevice `json:"BlockCIMs,omitempty"` + VolumeGuid guid.GUID `json:"VolumeGuid,omitempty"` + MountFlags uint32 `json:"MountFlags,omitempty"` + ContainerID string +} + type WCOWMappedVirtualDisk struct { ContainerPath string `json:"ContainerPath,omitempty"` Lun int32 `json:"Lun,omitempty"` @@ -207,3 +241,19 @@ type LCOWConfidentialOptions struct { type LCOWSecurityPolicyFragment struct { Fragment string `json:"Fragment,omitempty"` } + +type WCOWConfidentialOptions struct { + EnforcerType string `json:"EnforcerType,omitempty"` + EncodedSecurityPolicy string `json:"EncodedSecurityPolicy,omitempty"` + // Optional security policy + WCOWSecurityPolicy string + // Set when there is a security policy to apply on actual SNP hardware, use this rather than checking the string length + WCOWSecurityPolicyEnabled bool + // Set which security policy enforcer to use (open door or rego). This allows for better fallback mechanic.
+ WCOWSecurityPolicyEnforcer string + NoSecurityHardware bool +} + +type WCOWSecurityPolicyFragment struct { + Fragment string `json:"Fragment,omitempty"` +} diff --git a/internal/pspdriver/pspdriver.go b/internal/pspdriver/pspdriver.go new file mode 100644 index 0000000000..9968cb3685 --- /dev/null +++ b/internal/pspdriver/pspdriver.go @@ -0,0 +1,105 @@ +//go:build windows +// +build windows + +package pspdriver + +import ( + "context" + "fmt" + "syscall" + "unsafe" + + winio "github.com/Microsoft/go-winio" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/pkg/errors" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" +) + +const ( + serviceName = "AmdSnpPsp" + snpFirmwareEnvVariable = "SnpGuestReport" + privilegeName = "SeSystemEnvironmentPrivilege" + amdSevSnpGUIDStr = "{4c3bddb9-c2b1-4cbd-9e0c-cb45e9e0e168}" +) + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + procGetFirmwareVar = kernel32.NewProc("GetFirmwareEnvironmentVariableW") +) + +func StartPSPDriver(ctx context.Context) error { + // Connect to the Service Control Manager + m, err := mgr.Connect() + if err != nil { + return errors.Wrap(err, "Failed to connect to service manager") + } + defer m.Disconnect() + + // Open the service + s, err := m.OpenService(serviceName) + if err != nil { + return errors.Wrapf(err, "Could not access service %q", serviceName) + } + defer s.Close() + + // Start the service + err = s.Start() + if err != nil { + return errors.Wrapf(err, "Could not start service %q", serviceName) + } + + log.G(ctx).Tracef("Service %q started successfully", serviceName) + + // TODO cleanup (kiashok): confirm the running state of the pspdriver + status, err := s.Query() + if err != nil { + return errors.Wrap(err, "could not query service status") + } + + switch status.State { + case svc.Running: + fmt.Println("Service is running.") + case svc.Stopped: + fmt.Println("Service is stopped.") + case svc.StartPending: + fmt.Println("Service is starting.") + case 
 svc.StopPending: + fmt.Println("Service is stopping.") + default: + fmt.Printf("Service state: %v\n", status.State) + } + return nil +} + +// IsSNPEnabled() returns true if SNP support is available. +func IsSNPEnabled(ctx context.Context) bool { + // GetFirmwareEnvironmentVariableW() requires the SeSystemEnvironmentPrivilege privilege. + // https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-getfirmwareenvironmentvariable + err := winio.EnableProcessPrivileges([]string{privilegeName}) + if err != nil { + log.G(ctx).WithError(err).Errorf("enabling privilege failed") + return false + } + + // UEFI variable name for SNP + firmwareEnvVar, _ := syscall.UTF16PtrFromString(snpFirmwareEnvVariable) + amdSnpGUID, _ := syscall.UTF16PtrFromString(amdSevSnpGUIDStr) + // Prepare buffer for data + // SNP report is max of 4KB + buffer := make([]byte, 4096) + + r1, _, err := procGetFirmwareVar.Call( + uintptr(unsafe.Pointer(firmwareEnvVar)), + uintptr(unsafe.Pointer(amdSnpGUID)), + uintptr(unsafe.Pointer(&buffer[0])), + uintptr(len(buffer)), + ) + + if r1 == 0 { + log.G(ctx).WithError(err).Debugf("SNP report not available") + return false + } + + return true +} diff --git a/internal/regopolicyinterpreter/regopolicyinterpreter.go b/internal/regopolicyinterpreter/regopolicyinterpreter.go index 47dbee28ea..047a4a27b7 100644 --- a/internal/regopolicyinterpreter/regopolicyinterpreter.go +++ b/internal/regopolicyinterpreter/regopolicyinterpreter.go @@ -269,6 +269,20 @@ func (m regoMetadata) getOrCreate(name string) map[string]interface{} { return metadata } +func (r *RegoPolicyInterpreter) UpdateOSType(os string) error { + r.dataAndModulesMutex.Lock() + defer r.dataAndModulesMutex.Unlock() + ops := []*regoMetadataOperation{ + { + Action: metadataAdd, + Name: "operatingsystem", + Key: "ostype", + Value: os, + }, + } + return r.updateMetadata(ops) +} + func (r *RegoPolicyInterpreter) updateMetadata(ops []*regoMetadataOperation) error { // dataAndModulesMutex must be held
before calling this diff --git a/internal/tools/policyenginesimulator/main.go b/internal/tools/policyenginesimulator/main.go index 7eec934631..e374560939 100644 --- a/internal/tools/policyenginesimulator/main.go +++ b/internal/tools/policyenginesimulator/main.go @@ -90,6 +90,7 @@ func createInterpreter() *rpi.RegoPolicyInterpreter { } r, err := rpi.NewRegoPolicyInterpreter(policyCode, data) + r.UpdateOSType("linux") if err != nil { log.Fatal(err) } diff --git a/internal/tools/uvmboot/conf_wcow.go b/internal/tools/uvmboot/conf_wcow.go index 7d76b7cf3b..6ff390d1c0 100644 --- a/internal/tools/uvmboot/conf_wcow.go +++ b/internal/tools/uvmboot/conf_wcow.go @@ -19,6 +19,9 @@ const ( vmgsFilePathArgName = "vmgs-path" disableSBArgName = "disable-secure-boot" isolationTypeArgName = "isolation-type" + + // default policy (that allows all operations) used when no policy is provided + allowAllPolicy = "cGFja2FnZSBwb2xpY3kKCmFwaV92ZXJzaW9uIDo9ICIwLjExLjAiCmZyYW1ld29ya192ZXJzaW9uIDo9ICIwLjQuMCIKCm1vdW50X2NpbXMgOj0geyJhbGxvd2VkIjogdHJ1ZX0KbW91bnRfZGV2aWNlIDo9IHsiYWxsb3dlZCI6IHRydWV9Cm1vdW50X292ZXJsYXkgOj0geyJhbGxvd2VkIjogdHJ1ZX0KY3JlYXRlX2NvbnRhaW5lciA6PSB7ImFsbG93ZWQiOiB0cnVlLCAiZW52X2xpc3QiOiBudWxsLCAiYWxsb3dfc3RkaW9fYWNjZXNzIjogdHJ1ZX0KdW5tb3VudF9kZXZpY2UgOj0geyJhbGxvd2VkIjogdHJ1ZX0KdW5tb3VudF9vdmVybGF5IDo9IHsiYWxsb3dlZCI6IHRydWV9CmV4ZWNfaW5fY29udGFpbmVyIDo9IHsiYWxsb3dlZCI6IHRydWUsICJlbnZfbGlzdCI6IG51bGx9CmV4ZWNfZXh0ZXJuYWwgOj0geyJhbGxvd2VkIjogdHJ1ZSwgImVudl9saXN0IjogbnVsbCwgImFsbG93X3N0ZGlvX2FjY2VzcyI6IHRydWV9CnNodXRkb3duX2NvbnRhaW5lciA6PSB7ImFsbG93ZWQiOiB0cnVlfQpzaWduYWxfY29udGFpbmVyX3Byb2Nlc3MgOj0geyJhbGxvd2VkIjogdHJ1ZX0KcGxhbjlfbW91bnQgOj0geyJhbGxvd2VkIjogdHJ1ZX0KcGxhbjlfdW5tb3VudCA6PSB7ImFsbG93ZWQiOiB0cnVlfQpnZXRfcHJvcGVydGllcyA6PSB7ImFsbG93ZWQiOiB0cnVlfQpkdW1wX3N0YWNrcyA6PSB7ImFsbG93ZWQiOiB0cnVlfQpydW50aW1lX2xvZ2dpbmcgOj0geyJhbGxvd2VkIjogdHJ1ZX0KbG9hZF9mcmFnbWVudCA6PSB7ImFsbG93ZWQiOiB0cnVlfQpzY3JhdGNoX21vdW50IDo9IHsiYWxsb3dlZCI6IHRydWV9CnNjcmF0Y2hfdW5tb3VudCA6PSB7ImFsbG93
ZWQiOiB0cnVlfQo=" ) var ( @@ -28,6 +31,7 @@ var ( cwcowVMGSPath string cwcowDisableSecureBoot bool cwcowIsolationMode string + cwcowSecurityPolicy string ) var cwcowCommand = cli.Command{ @@ -79,6 +83,16 @@ var cwcowCommand = cli.Command{ Destination: &cwcowIsolationMode, Required: true, }, + cli.StringFlag{ + Name: securityPolicyArgName, + Usage: "Security policy that should be enforced inside the UVM. If none is provided, default policy that allows all operations will be used.", + Destination: &cwcowSecurityPolicy, + Value: allowAllPolicy, + }, + cli.BoolFlag{ + Name: securityHardwareFlag, + Usage: "If set, UVM won't boot on non-SNP hardware. Set to false by default", + }, }, Action: func(c *cli.Context) error { runMany(c, func(id string) error { @@ -91,6 +105,11 @@ var cwcowCommand = cli.Command{ // confidential specific options options.SecurityPolicyEnabled = true + options.SecurityPolicy = cwcowSecurityPolicy + options.NoSecurityHardware = true + if c.IsSet(securityHardwareFlag) { + options.NoSecurityHardware = false + } options.DisableSecureBoot = cwcowDisableSecureBoot options.GuestStateFilePath = cwcowVMGSPath options.IsolationType = cwcowIsolationMode diff --git a/internal/uvm/cimfs.go b/internal/uvm/cimfs.go new file mode 100644 index 0000000000..b97ee0d00b --- /dev/null +++ b/internal/uvm/cimfs.go @@ -0,0 +1,96 @@ +//go:build windows +// +build windows + +package uvm + +import ( + "context" + "fmt" + + "github.com/Microsoft/go-winio/pkg/guid" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/uvm/scsi" + "github.com/Microsoft/hcsshim/pkg/cimfs" + "github.com/sirupsen/logrus" +) + +type UVMMountedBlockCIMs struct { + scsiMounts []*scsi.Mount + // Volume Path inside the UVM at which the CIMs are mounted + VolumePath string +} + +func (umb *UVMMountedBlockCIMs) Release(ctx context.Context) error 
{ + for i := len(umb.scsiMounts) - 1; i >= 0; i-- { + if err := umb.scsiMounts[i].Release(ctx); err != nil { + return err + } + } + return nil +} + +// mergedCIM can be nil, +// sourceCIMs MUST be in the top to bottom order +func (uvm *UtilityVM) MountBlockCIMs(ctx context.Context, mergedCIM *cimfs.BlockCIM, sourceCIMs []*cimfs.BlockCIM, containerID string) (_ *UVMMountedBlockCIMs, err error) { + volumeGUID, err := guid.NewV4() + if err != nil { + return nil, fmt.Errorf("generated cim mount GUID: %w", err) + } + + layersToAttach := sourceCIMs + if mergedCIM != nil { + layersToAttach = append([]*cimfs.BlockCIM{mergedCIM}, sourceCIMs...) + } + + settings := &guestresource.WCOWBlockCIMMounts{ + BlockCIMs: []guestresource.BlockCIMDevice{}, + VolumeGuid: volumeGUID, + MountFlags: cimfs.CimMountVerifiedCim, + ContainerID: containerID, + } + + umb := &UVMMountedBlockCIMs{ + VolumePath: fmt.Sprintf(cimfs.VolumePathFormat, volumeGUID.String()), + scsiMounts: []*scsi.Mount{}, + } + + for _, bcim := range layersToAttach { + sm, err := uvm.SCSIManager.AddVirtualDisk(ctx, bcim.BlockPath, true, uvm.ID(), "", nil) + if err != nil { + return nil, fmt.Errorf("failed to attach block CIM %s: %w", bcim.BlockPath, err) + } + + log.G(ctx).WithFields(logrus.Fields{ + "block path": bcim.BlockPath, + "cim name": bcim.CimName, + "scsi controller": sm.Controller(), + "scsi LUN": sm.LUN(), + }).Debugf("attached block CIM VHD") + + settings.BlockCIMs = append(settings.BlockCIMs, guestresource.BlockCIMDevice{ + CimName: bcim.CimName, + Lun: int32(sm.LUN()), + }) + umb.scsiMounts = append(umb.scsiMounts, sm) + defer func() { + if err != nil { + relErr := sm.Release(ctx) + if relErr != nil { + log.G(ctx).WithError(err).Warnf("cleanup on failure error: %w", relErr) + } + } + }() + } + + guestReq := guestrequest.ModificationRequest{ + ResourceType: guestresource.ResourceTypeWCOWBlockCims, + RequestType: guestrequest.RequestTypeAdd, + Settings: settings, + } + if err := uvm.GuestRequest(ctx, 
guestReq); err != nil { + return nil, fmt.Errorf("failed to mount the cim: %w", err) + } + return umb, nil +} diff --git a/internal/uvm/combine_layers.go b/internal/uvm/combine_layers.go index 468139c0f7..6d3919a4da 100644 --- a/internal/uvm/combine_layers.go +++ b/internal/uvm/combine_layers.go @@ -10,11 +10,32 @@ import ( "github.com/Microsoft/hcsshim/internal/protocol/guestresource" ) +func (uvm *UtilityVM) CombineLayersForCWCOW(ctx context.Context, layerPaths []hcsschema.Layer, containerRootPath string, containerID string, filterType hcsschema.FileSystemFilterType) error { + if uvm.operatingSystem != "windows" { + return errNotSupported + } + msr := &hcsschema.ModifySettingRequest{ + GuestRequest: guestrequest.ModificationRequest{ + ResourceType: guestresource.ResourceTypeCWCOWCombinedLayers, + RequestType: guestrequest.RequestTypeAdd, + Settings: guestresource.CWCOWCombinedLayers{ + ContainerID: containerID, + CombinedLayers: guestresource.WCOWCombinedLayers{ + ContainerRootPath: containerRootPath, + Layers: layerPaths, + FilterType: filterType, + }, + }, + }, + } + return uvm.modify(ctx, msr) +} + // CombineLayersWCOW combines `layerPaths` with `containerRootPath` into the // container file system. // // Note: `layerPaths` and `containerRootPath` are paths from within the UVM. 
-func (uvm *UtilityVM) CombineLayersWCOW(ctx context.Context, layerPaths []hcsschema.Layer, containerRootPath string) error { +func (uvm *UtilityVM) CombineLayersWCOW(ctx context.Context, layerPaths []hcsschema.Layer, containerRootPath string, filterType hcsschema.FileSystemFilterType) error { if uvm.operatingSystem != "windows" { return errNotSupported } @@ -25,6 +46,7 @@ func (uvm *UtilityVM) CombineLayersWCOW(ctx context.Context, layerPaths []hcssch Settings: guestresource.WCOWCombinedLayers{ ContainerRootPath: containerRootPath, Layers: layerPaths, + FilterType: filterType, }, }, } diff --git a/internal/uvm/create_lcow.go b/internal/uvm/create_lcow.go index 1f310f8d60..338c742fc2 100644 --- a/internal/uvm/create_lcow.go +++ b/internal/uvm/create_lcow.go @@ -20,7 +20,7 @@ import ( "go.opencensus.io/trace" "github.com/Microsoft/hcsshim/internal/copyfile" - "github.com/Microsoft/hcsshim/internal/gcs" + "github.com/Microsoft/hcsshim/internal/gcs/prot" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields" @@ -438,7 +438,7 @@ func makeLCOWVMGSDoc(ctx context.Context, opts *OptionsLCOW, uvm *UtilityVM) (_ // entropyVsockPort - 1 is the entropy port, // linuxLogVsockPort - 109 used by vsockexec to log stdout/stderr logging, // 0x40000000 + 1 (LinuxGcsVsockPort + 1) is the bridge (see guestconnectiuon.go) - hvSockets := []uint32{entropyVsockPort, linuxLogVsockPort, gcs.LinuxGcsVsockPort, gcs.LinuxGcsVsockPort + 1} + hvSockets := []uint32{entropyVsockPort, linuxLogVsockPort, prot.LinuxGcsVsockPort, prot.LinuxGcsVsockPort + 1} hvSockets = append(hvSockets, opts.ExtraVSockPorts...) 
for _, whichSocket := range hvSockets { key := winio.VsockServiceID(whichSocket).String() @@ -984,7 +984,7 @@ func CreateLCOW(ctx context.Context, opts *OptionsLCOW) (_ *UtilityVM, err error if opts.UseGuestConnection { log.G(ctx).WithField("vmID", uvm.runtimeID).Debug("Using external GCS bridge") - l, err := uvm.listenVsock(gcs.LinuxGcsVsockPort) + l, err := uvm.listenVsock(prot.LinuxGcsVsockPort) if err != nil { return nil, err } diff --git a/internal/uvm/create_wcow.go b/internal/uvm/create_wcow.go index 0b91e42cf2..294535f6ee 100644 --- a/internal/uvm/create_wcow.go +++ b/internal/uvm/create_wcow.go @@ -15,13 +15,14 @@ import ( "github.com/sirupsen/logrus" "go.opencensus.io/trace" - "github.com/Microsoft/hcsshim/internal/gcs" + "github.com/Microsoft/hcsshim/internal/gcs/prot" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields" "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/internal/processorinfo" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "github.com/Microsoft/hcsshim/internal/schemaversion" "github.com/Microsoft/hcsshim/internal/security" "github.com/Microsoft/hcsshim/internal/uvm/scsi" @@ -30,14 +31,18 @@ import ( ) type ConfidentialWCOWOptions struct { - GuestStateFilePath string // The vmgs file path - SecurityPolicyEnabled bool // Set when there is a security policy to apply on actual SNP hardware, use this rathen than checking the string length - SecurityPolicy string // Optional security policy + GuestStateFilePath string // The vmgs file path + SecurityPolicyEnabled bool // Set when there is a security policy to apply on actual SNP hardware, use this rathen than checking the string length + SecurityPolicy string // Optional security policy + SecurityPolicyEnforcer string /* Below options are only included for testing/debugging purposes - 
shouldn't be used in regular scenarios */ IsolationType string DisableSecureBoot bool FirmwareParameters string + + // Temp (kiashok): + NoSecurityHardware bool } // OptionsWCOW are the set of options passed to CreateWCOW() to create a utility vm. @@ -81,6 +86,42 @@ func NewDefaultOptionsWCOW(id, owner string) *OptionsWCOW { } } +// SetDefaultConfidentialWCOWBootConfig updates the given WCOW UVM creation options (with the +// default values) so that the created UVM does a confidential boot. +func SetDefaultConfidentialWCOWBootConfig(opts *OptionsWCOW) error { + selfDir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + return fmt.Errorf("failed to get absolute path to shim directory: %w", err) + } + + bootDir := filepath.Join(selfDir, "WindowsBootFiles", "confidential") + opts.GuestStateFilePath = filepath.Join(bootDir, "cwcow.vmgs") + opts.BootFiles = &WCOWBootFiles{ + BootType: BlockCIMBoot, + BlockCIMFiles: &BlockCIMBootFiles{ + BootCIMVHDPath: filepath.Join(bootDir, "rootfs.vhd"), + EFIVHDPath: filepath.Join(bootDir, "boot.vhd"), + ScratchVHDPath: filepath.Join(bootDir, "scratch.vhd"), + }, + } + for _, path := range []string{ + opts.GuestStateFilePath, + opts.BootFiles.BlockCIMFiles.BootCIMVHDPath, + opts.BootFiles.BlockCIMFiles.EFIVHDPath, + opts.BootFiles.BlockCIMFiles.ScratchVHDPath} { + if _, err := os.Stat(path); err != nil { + return fmt.Errorf("failed to stat boot file `%s` for confidential WCOW: %w", path, err) + } + } + + //TODO(ambarve): for testing only remove later + opts.IsolationType = "VirtualizationBasedSecurity" + opts.DisableSecureBoot = false + opts.ConsolePipe = "\\\\.\\pipe\\uvmpipe" + opts.NoSecurityHardware = true + return nil +} + // startExternalGcsListener connects to the GCS service running inside the // UVM. 
gcsServiceID can either be the service ID of the default GCS that is present in // all UtilityVMs or it can be the service ID of the sidecar GCS that is used mostly in @@ -337,10 +378,6 @@ func prepareSecurityConfigDoc(ctx context.Context, uvm *UtilityVM, opts *Options doc.VirtualMachine.SecuritySettings.Isolation.IsolationType = opts.IsolationType } - if err := wclayer.GrantVmAccess(ctx, uvm.id, opts.GuestStateFilePath); err != nil { - return nil, errors.Wrap(err, "failed to grant vm access to guest state file") - } - doc.VirtualMachine.GuestState = &hcsschema.GuestState{ GuestStateFilePath: opts.GuestStateFilePath, GuestStateFileType: "BlockStorage", @@ -377,8 +414,9 @@ func prepareSecurityConfigDoc(ctx context.Context, uvm *UtilityVM, opts *Options Type_: "VirtualDisk", } doc.VirtualMachine.Devices.Scsi[guestrequest.ScsiControllerGuids[0]].Attachments["1"] = hcsschema.Attachment{ - Path: opts.BootFiles.BlockCIMFiles.EFIVHDPath, - Type_: "VirtualDisk", + Path: opts.BootFiles.BlockCIMFiles.EFIVHDPath, + Type_: "VirtualDisk", + ReadOnly: true, } doc.VirtualMachine.Devices.Scsi[guestrequest.ScsiControllerGuids[0]].Attachments["2"] = hcsschema.Attachment{ Path: opts.BootFiles.BlockCIMFiles.BootCIMVHDPath, @@ -487,6 +525,12 @@ func CreateWCOW(ctx context.Context, opts *OptionsWCOW) (_ *UtilityVM, err error var doc *hcsschema.ComputeSystem if opts.SecurityPolicyEnabled { + uvm.WCOWconfidentialUVMOptions = &guestresource.WCOWConfidentialOptions{ + WCOWSecurityPolicyEnabled: true, + WCOWSecurityPolicy: opts.SecurityPolicy, + WCOWSecurityPolicyEnforcer: opts.SecurityPolicyEnforcer, + NoSecurityHardware: opts.NoSecurityHardware, + } doc, err = prepareSecurityConfigDoc(ctx, uvm, opts) log.G(ctx).Tracef("CreateWCOW prepareSecurityConfigDoc result doc: %v err %v", doc, err) } else { @@ -502,7 +546,7 @@ func CreateWCOW(ctx context.Context, opts *OptionsWCOW) (_ *UtilityVM, err error return nil, fmt.Errorf("error while creating the compute system: %w", err) } - gcsServiceID := 
gcs.WindowsGcsHvsockServiceID + gcsServiceID := prot.WindowsGcsHvsockServiceID if opts.SecurityPolicyEnabled { gcsServiceID = windowsSidecarGcsHvsockServiceID } diff --git a/internal/uvm/scsi/backend.go b/internal/uvm/scsi/backend.go index 6219a15172..130ee66df6 100644 --- a/internal/uvm/scsi/backend.go +++ b/internal/uvm/scsi/backend.go @@ -170,6 +170,13 @@ func mountRequest(controller, lun uint, path string, config *mountConfig, osType ResourceType: guestresource.ResourceTypeMappedVirtualDisk, RequestType: guestrequest.RequestTypeAdd, } + // This option is set only for cwcow scratch disk mount requests + // where we need to format the disk with refs. + // For refs the scratch disk size should > 30 GB. + if config.formatWithRefs { + req.ResourceType = guestresource.ResourceTypeMappedVirtualDiskForContainerScratch + } + switch osType { case "windows": // We don't check config.readOnly here, as that will still result in the overall attachment being read-only. @@ -185,6 +192,7 @@ func mountRequest(controller, lun uint, path string, config *mountConfig, osType ContainerPath: path, Lun: int32(lun), } + case "linux": req.Settings = guestresource.LCOWMappedVirtualDisk{ MountPath: path, diff --git a/internal/uvm/scsi/manager.go b/internal/uvm/scsi/manager.go index ff8374038e..272ea9c992 100644 --- a/internal/uvm/scsi/manager.go +++ b/internal/uvm/scsi/manager.go @@ -86,6 +86,9 @@ type MountConfig struct { // BlockDev indicates if the device should be mounted as a block device. // This is only supported for LCOW. BlockDev bool + // FormatWithRefs indicates to refs format the disk. + // This is only supported for CWCOW scratch disks. 
+ FormatWithRefs bool } // Mount represents a SCSI device that has been attached to a VM, and potentially @@ -162,6 +165,7 @@ func (m *Manager) AddVirtualDisk( ensureFilesystem: mc.EnsureFilesystem, filesystem: mc.Filesystem, blockDev: mc.BlockDev, + formatWithRefs: mc.FormatWithRefs, } } return m.add(ctx, diff --git a/internal/uvm/scsi/mount.go b/internal/uvm/scsi/mount.go index 68e36f1c9c..4696ac15a5 100644 --- a/internal/uvm/scsi/mount.go +++ b/internal/uvm/scsi/mount.go @@ -45,6 +45,7 @@ type mountConfig struct { options []string ensureFilesystem bool filesystem string + formatWithRefs bool } func (mm *mountManager) mount(ctx context.Context, controller, lun uint, path string, c *mountConfig) (_ string, err error) { diff --git a/internal/uvm/security_policy.go b/internal/uvm/security_policy.go index 195b93d3ce..266601af3d 100644 --- a/internal/uvm/security_policy.go +++ b/internal/uvm/security_policy.go @@ -34,6 +34,59 @@ func WithSecurityPolicyEnforcer(enforcer string) ConfidentialUVMOpt { } } +// TODO (Mahati): Move this block out later +type WCOWConfidentialUVMOpt func(ctx context.Context, r *guestresource.WCOWConfidentialOptions) error + +// WithSecurityPolicy sets the desired security policy for the resource. +func WithWCOWSecurityPolicy(policy string) WCOWConfidentialUVMOpt { + return func(ctx context.Context, r *guestresource.WCOWConfidentialOptions) error { + r.EncodedSecurityPolicy = policy + return nil + } +} + +func WithWCOWNoSecurityHardware(noSecurityHardware bool) WCOWConfidentialUVMOpt { + return func(ctx context.Context, r *guestresource.WCOWConfidentialOptions) error { + r.NoSecurityHardware = noSecurityHardware + return nil + } +} + +// WithSecurityPolicyEnforcer sets the desired enforcer type for the resource. 
+func WithWCOWSecurityPolicyEnforcer(enforcer string) WCOWConfidentialUVMOpt { + return func(ctx context.Context, r *guestresource.WCOWConfidentialOptions) error { + r.EnforcerType = enforcer + return nil + } +} + +// TODO: Separate this out later +func (uvm *UtilityVM) SetWCOWConfidentialUVMOptions(ctx context.Context, opts ...WCOWConfidentialUVMOpt) error { + if uvm.operatingSystem != "windows" { + return errNotSupported + } + uvm.m.Lock() + defer uvm.m.Unlock() + confOpts := &guestresource.WCOWConfidentialOptions{} + for _, o := range opts { + if err := o(ctx, confOpts); err != nil { + return err + } + } + modification := &hcsschema.ModifySettingRequest{ + RequestType: guestrequest.RequestTypeAdd, + GuestRequest: guestrequest.ModificationRequest{ + ResourceType: guestresource.ResourceTypeSecurityPolicy, + RequestType: guestrequest.RequestTypeAdd, + Settings: *confOpts, + }, + } + if err := uvm.modify(ctx, modification); err != nil { + return fmt.Errorf("uvm::Policy: failed to modify utility VM configuration: %w", err) + } + return nil +} + func base64EncodeFileContents(filePath string) (string, error) { if filePath == "" { return "", nil @@ -88,7 +141,7 @@ func (uvm *UtilityVM) SetConfidentialUVMOptions(ctx context.Context, opts ...Con } } modification := &hcsschema.ModifySettingRequest{ - RequestType: guestrequest.RequestTypeAdd, + //RequestType: guestrequest.RequestTypeAdd, GuestRequest: guestrequest.ModificationRequest{ ResourceType: guestresource.ResourceTypeSecurityPolicy, RequestType: guestrequest.RequestTypeAdd, diff --git a/internal/uvm/start.go b/internal/uvm/start.go index ec8cdf1ee9..4eb7283d58 100644 --- a/internal/uvm/start.go +++ b/internal/uvm/start.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sys/windows" "github.com/Microsoft/hcsshim/internal/gcs" + "github.com/Microsoft/hcsshim/internal/gcs/prot" "github.com/Microsoft/hcsshim/internal/hcs" "github.com/Microsoft/hcsshim/internal/hcs/schema1" hcsschema 
"github.com/Microsoft/hcsshim/internal/hcs/schema2" @@ -134,7 +135,7 @@ func (uvm *UtilityVM) configureHvSocketForGCS(ctx context.Context) (err error) { hvsocketAddress := &hcsschema.HvSocketAddress{ LocalAddress: uvm.runtimeID.String(), - ParentAddress: gcs.WindowsGcsHvHostID.String(), + ParentAddress: prot.WindowsGcsHvHostID.String(), } conSetupReq := &hcsschema.ModifySettingRequest{ @@ -156,7 +157,7 @@ func (uvm *UtilityVM) configureHvSocketForGCS(ctx context.Context) (err error) { func (uvm *UtilityVM) Start(ctx context.Context) (err error) { // save parent context, without timeout to use in terminate pCtx := ctx - ctx, cancel := context.WithTimeout(pCtx, 2*time.Minute) + ctx, cancel := context.WithTimeout(pCtx, 200*time.Minute) g, gctx := errgroup.WithContext(ctx) defer func() { _ = g.Wait() @@ -333,6 +334,17 @@ func (uvm *UtilityVM) Start(ctx context.Context) (err error) { } } + if uvm.WCOWconfidentialUVMOptions != nil && uvm.OS() == "windows" { + copts := []WCOWConfidentialUVMOpt{ + WithWCOWSecurityPolicy(uvm.WCOWconfidentialUVMOptions.WCOWSecurityPolicy), + WithWCOWSecurityPolicyEnforcer(uvm.WCOWconfidentialUVMOptions.WCOWSecurityPolicyEnforcer), + WithWCOWNoSecurityHardware(uvm.WCOWconfidentialUVMOptions.NoSecurityHardware), + } + if err := uvm.SetWCOWConfidentialUVMOptions(ctx, copts...); err != nil { + return err + } + } + return nil } diff --git a/internal/uvm/types.go b/internal/uvm/types.go index 150b204999..a301e396ac 100644 --- a/internal/uvm/types.go +++ b/internal/uvm/types.go @@ -14,6 +14,7 @@ import ( "github.com/Microsoft/hcsshim/hcn" "github.com/Microsoft/hcsshim/internal/gcs" "github.com/Microsoft/hcsshim/internal/hcs" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "github.com/Microsoft/hcsshim/internal/uvm/scsi" ) @@ -142,6 +143,8 @@ type UtilityVM struct { // LCOW only. Indicates whether to use policy based routing when configuring net interfaces in the guest. 
policyBasedRouting bool + // WCOWconfidentialUVMOptions hold confidential UVM specific options + WCOWconfidentialUVMOptions *guestresource.WCOWConfidentialOptions } func (uvm *UtilityVM) ScratchEncryptionEnabled() bool { diff --git a/internal/wclayer/cim/block_cim_writer.go b/internal/wclayer/cim/block_cim_writer.go index 1e7da68c05..33e9a4c23f 100644 --- a/internal/wclayer/cim/block_cim_writer.go +++ b/internal/wclayer/cim/block_cim_writer.go @@ -5,10 +5,15 @@ package cim import ( "context" "fmt" + "os" "path/filepath" + "strconv" "github.com/Microsoft/go-winio" + "github.com/Microsoft/hcsshim/ext4/tar2ext4" "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/wclayer" + "github.com/Microsoft/hcsshim/osversion" "github.com/Microsoft/hcsshim/pkg/cimfs" ) @@ -133,3 +138,47 @@ func (cw *BlockCIMLayerWriter) AddLink(name string, target string) error { return nil } + +func (cw *BlockCIMLayerWriter) Close(ctx context.Context) error { + processUtilityVM := false + if cw.hasUtilityVM { + uvmSoftwareHivePath := filepath.Join(cw.layerPath, wclayer.UtilityVMPath, wclayer.RegFilesPath, "SOFTWARE") + osvStr, err := getOsBuildNumberFromRegistry(uvmSoftwareHivePath) + if err != nil { + return fmt.Errorf("read os version string from UtilityVM SOFTWARE hive: %w", err) + } + + osv, err := strconv.ParseUint(osvStr, 10, 16) + if err != nil { + return fmt.Errorf("parse os version string (%s): %w", osvStr, err) + } + + // write this version to a file for future reference by the shim process + if err = wclayer.WriteLayerUvmBuildFile(cw.layerPath, uint16(osv)); err != nil { + return fmt.Errorf("write uvm build version: %w", err) + } + + // TODO(ambarve): use the accurate OS version here. 
+ // CIMFS for hyperV isolated is only supported after WS2025, processing + // UtilityVM layer lower builds will cause failures since those images + // won't have CIMFS specific UVM files (mostly BCD entries required for + // CIMFS) + processUtilityVM = (osv >= osversion.LTSC2025) + log.G(ctx).Debugf("import image os version %d, processing UtilityVM layer: %t\n", osv, processUtilityVM) + } + if err := cw.cimLayerWriter.Close(ctx, processUtilityVM); err != nil { + return fmt.Errorf("failed to close cim layer writer: %w", err) + } + // append footer only after all writers are closed + + blockFile, err := os.OpenFile(cw.layer.BlockPath, os.O_WRONLY, 0777) + if err != nil { + return fmt.Errorf("failed to open block CIM to append VHD footer: %w", err) + } + defer blockFile.Close() + + if err := tar2ext4.ConvertToVhd(blockFile); err != nil { + return fmt.Errorf("failed to append VHD footer: %w", err) + } + return nil +} diff --git a/internal/wclayer/cim/common.go b/internal/wclayer/cim/common.go index 391a5aaeda..0dfeffe531 100644 --- a/internal/wclayer/cim/common.go +++ b/internal/wclayer/cim/common.go @@ -169,7 +169,7 @@ func (cw *cimLayerWriter) Write(b []byte) (int, error) { } // Close finishes the layer writing process and releases any resources. -func (cw *cimLayerWriter) Close(ctx context.Context) (retErr error) { +func (cw *cimLayerWriter) Close(ctx context.Context, processUtilityVM bool) (retErr error) { if err := cw.stdFileWriter.Close(ctx); err != nil { return err } @@ -181,9 +181,6 @@ func (cw *cimLayerWriter) Close(ctx context.Context) (retErr error) { } }() - // We don't support running UtilityVM with CIM layers yet. 
- processUtilityVM := false - if len(cw.parentLayerPaths) == 0 { if err := cw.processBaseLayer(ctx, processUtilityVM); err != nil { return fmt.Errorf("process base layer: %w", err) diff --git a/internal/wclayer/cim/forked_cim_writer.go b/internal/wclayer/cim/forked_cim_writer.go index 7da052b515..c5c311e4fa 100644 --- a/internal/wclayer/cim/forked_cim_writer.go +++ b/internal/wclayer/cim/forked_cim_writer.go @@ -76,3 +76,9 @@ func (cw *ForkedCimLayerWriter) Remove(name string) error { } return fmt.Errorf("failed to remove file: %w", err) } + +// Close finishes the layer writing process and releases any resources. +func (cw *ForkedCimLayerWriter) Close(ctx context.Context) error { + // we don't support running UVMs with forked CIM layers + return cw.cimLayerWriter.Close(ctx, false) +} diff --git a/internal/wclayer/cim/mount.go b/internal/wclayer/cim/mount.go index 56d0d0ac7d..89fa13ccdc 100644 --- a/internal/wclayer/cim/mount.go +++ b/internal/wclayer/cim/mount.go @@ -108,6 +108,7 @@ func MergeMountBlockCIMLayer(ctx context.Context, mergedLayer *cimfs.BlockCIM, p if err != nil { return "", fmt.Errorf("generated cim mount GUID: %w", err) } + return cimfs.MountMergedBlockCIMs(mergedLayer, parentLayers, mountFlags, volumeGUID) } diff --git a/internal/wclayer/cim/process.go b/internal/wclayer/cim/process.go index ace81122bc..f4ad6cb293 100644 --- a/internal/wclayer/cim/process.go +++ b/internal/wclayer/cim/process.go @@ -16,6 +16,10 @@ import ( // processUtilityVMLayer will handle processing of UVM specific files when we start // supporting UVM based containers with CimFS in the future. func processUtilityVMLayer(ctx context.Context, layerPath string) error { + // TODO(ambarve): + // 1. create a scratch VHD + // 2. create a diff scratch VHD + // 3. 
create an EFI partition VHD for boot return nil } diff --git a/internal/wclayer/cim/registry.go b/internal/wclayer/cim/registry.go index c95b03ca37..54c7a05607 100644 --- a/internal/wclayer/cim/registry.go +++ b/internal/wclayer/cim/registry.go @@ -5,9 +5,11 @@ package cim import ( "fmt" + "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/winapi" "github.com/Microsoft/hcsshim/osversion" "github.com/pkg/errors" + "github.com/sirupsen/logrus" ) // mergeHive merges the hive located at parentHivePath with the hive located at deltaHivePath and stores @@ -47,3 +49,51 @@ func mergeHive(parentHivePath, deltaHivePath, mergedHivePath string) (err error) } return } + +// getOsBuildNumberFromRegistry fetches the "CurrentBuild" value at path +// "Microsoft\Windows NT\CurrentVersion" from the SOFTWARE registry hive at path +// `regHivePath`. This is used to detect the build version of the uvm. +func getOsBuildNumberFromRegistry(regHivePath string) (_ string, err error) { + var storeHandle, keyHandle winapi.ORHKey + var dataType, dataLen uint32 + keyPath := "Microsoft\\Windows NT\\CurrentVersion" + valueName := "CurrentBuild" + dataLen = 16 // build version string can't be more than 5 wide chars? 
+ dataBuf := make([]byte, dataLen) + + if err = winapi.OROpenHive(regHivePath, &storeHandle); err != nil { + return "", fmt.Errorf("failed to open registry store at %s: %w", regHivePath, err) + } + defer func() { + if closeErr := winapi.ORCloseHive(storeHandle); closeErr != nil { + log.L.WithFields(logrus.Fields{ + "error": closeErr, + "hive": regHivePath, + }).Warnf("failed to close hive") + } + }() + + if err = winapi.OROpenKey(storeHandle, keyPath, &keyHandle); err != nil { + return "", fmt.Errorf("failed to open key at %s: %w", keyPath, err) + } + defer func() { + if closeErr := winapi.ORCloseKey(keyHandle); closeErr != nil { + log.L.WithFields(logrus.Fields{ + "error": closeErr, + "hive": regHivePath, + "key": keyPath, + "value": valueName, + }).Warnf("failed to close hive key") + } + }() + + if err = winapi.ORGetValue(keyHandle, "", valueName, &dataType, &dataBuf[0], &dataLen); err != nil { + return "", fmt.Errorf("failed to get value of %s: %w", valueName, err) + } + + if dataType != uint32(winapi.REG_TYPE_SZ) { + return "", fmt.Errorf("unexpected build number data type (%d)", dataType) + } + + return winapi.ParseUtf16LE(dataBuf[:(dataLen - 2)]), nil +} diff --git a/internal/winapi/cimfs.go b/internal/winapi/cimfs.go index 6c026d9822..cc3d254120 100644 --- a/internal/winapi/cimfs.go +++ b/internal/winapi/cimfs.go @@ -56,3 +56,6 @@ type CimFsImagePath struct { //sys CimMergeMountImage(numCimPaths uint32, backingImagePaths *CimFsImagePath, flags uint32, volumeID *g) (hr error) = cimfs.CimMergeMountImage? //sys CimTombstoneFile(cimFSHandle FsHandle, path string) (hr error) = cimfs.CimTombstoneFile? //sys CimCreateMergeLink(cimFSHandle FsHandle, newPath string, oldPath string) (hr error) = cimfs.CimCreateMergeLink? +//sys CimSealImage(blockCimPath string, hashSize *uint64, fixedHeaderSize *uint64, hash *byte) (hr error) = cimfs.CimSealImage? 
+//sys CimGetVerificationInformation(blockCimPath string, isSealed *uint32, hashSize *uint64, signatureSize *uint64, fixedHeaderSize *uint64, hash *byte, signature *byte) (hr error) = cimfs.CimGetVerificationInformation? +//sys CimMountVerifiedImage(imagePath string, fsName string, flags uint32, volumeID *g, hashSize uint16, hash *byte) (hr error) = cimfs.CimMountVerifiedImage? diff --git a/internal/winapi/zsyscall_windows.go b/internal/winapi/zsyscall_windows.go index db4fc1c961..a7eea44ec7 100644 --- a/internal/winapi/zsyscall_windows.go +++ b/internal/winapi/zsyscall_windows.go @@ -68,8 +68,11 @@ var ( procCimCreateMergeLink = modcimfs.NewProc("CimCreateMergeLink") procCimDeletePath = modcimfs.NewProc("CimDeletePath") procCimDismountImage = modcimfs.NewProc("CimDismountImage") + procCimGetVerificationInformation = modcimfs.NewProc("CimGetVerificationInformation") procCimMergeMountImage = modcimfs.NewProc("CimMergeMountImage") procCimMountImage = modcimfs.NewProc("CimMountImage") + procCimMountVerifiedImage = modcimfs.NewProc("CimMountVerifiedImage") + procCimSealImage = modcimfs.NewProc("CimSealImage") procCimTombstoneFile = modcimfs.NewProc("CimTombstoneFile") procCimWriteStream = modcimfs.NewProc("CimWriteStream") procSetJobCompartmentId = modiphlpapi.NewProc("SetJobCompartmentId") @@ -491,6 +494,30 @@ func CimDismountImage(volumeID *g) (hr error) { return } +func CimGetVerificationInformation(blockCimPath string, isSealed *uint32, hashSize *uint64, signatureSize *uint64, fixedHeaderSize *uint64, hash *byte, signature *byte) (hr error) { + var _p0 *uint16 + _p0, hr = syscall.UTF16PtrFromString(blockCimPath) + if hr != nil { + return + } + return _CimGetVerificationInformation(_p0, isSealed, hashSize, signatureSize, fixedHeaderSize, hash, signature) +} + +func _CimGetVerificationInformation(blockCimPath *uint16, isSealed *uint32, hashSize *uint64, signatureSize *uint64, fixedHeaderSize *uint64, hash *byte, signature *byte) (hr error) { + hr = 
procCimGetVerificationInformation.Find() + if hr != nil { + return + } + r0, _, _ := syscall.SyscallN(procCimGetVerificationInformation.Addr(), uintptr(unsafe.Pointer(blockCimPath)), uintptr(unsafe.Pointer(isSealed)), uintptr(unsafe.Pointer(hashSize)), uintptr(unsafe.Pointer(signatureSize)), uintptr(unsafe.Pointer(fixedHeaderSize)), uintptr(unsafe.Pointer(hash)), uintptr(unsafe.Pointer(signature))) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + func CimMergeMountImage(numCimPaths uint32, backingImagePaths *CimFsImagePath, flags uint32, volumeID *g) (hr error) { hr = procCimMergeMountImage.Find() if hr != nil { @@ -535,6 +562,59 @@ func _CimMountImage(imagePath *uint16, fsName *uint16, flags uint32, volumeID *g return } +func CimMountVerifiedImage(imagePath string, fsName string, flags uint32, volumeID *g, hashSize uint16, hash *byte) (hr error) { + var _p0 *uint16 + _p0, hr = syscall.UTF16PtrFromString(imagePath) + if hr != nil { + return + } + var _p1 *uint16 + _p1, hr = syscall.UTF16PtrFromString(fsName) + if hr != nil { + return + } + return _CimMountVerifiedImage(_p0, _p1, flags, volumeID, hashSize, hash) +} + +func _CimMountVerifiedImage(imagePath *uint16, fsName *uint16, flags uint32, volumeID *g, hashSize uint16, hash *byte) (hr error) { + hr = procCimMountVerifiedImage.Find() + if hr != nil { + return + } + r0, _, _ := syscall.SyscallN(procCimMountVerifiedImage.Addr(), uintptr(unsafe.Pointer(imagePath)), uintptr(unsafe.Pointer(fsName)), uintptr(flags), uintptr(unsafe.Pointer(volumeID)), uintptr(hashSize), uintptr(unsafe.Pointer(hash))) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func CimSealImage(blockCimPath string, hashSize *uint64, fixedHeaderSize *uint64, hash *byte) (hr error) { + var _p0 *uint16 + _p0, hr = syscall.UTF16PtrFromString(blockCimPath) + if hr != nil { + return + } + return _CimSealImage(_p0, 
hashSize, fixedHeaderSize, hash) +} + +func _CimSealImage(blockCimPath *uint16, hashSize *uint64, fixedHeaderSize *uint64, hash *byte) (hr error) { + hr = procCimSealImage.Find() + if hr != nil { + return + } + r0, _, _ := syscall.SyscallN(procCimSealImage.Addr(), uintptr(unsafe.Pointer(blockCimPath)), uintptr(unsafe.Pointer(hashSize)), uintptr(unsafe.Pointer(fixedHeaderSize)), uintptr(unsafe.Pointer(hash))) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + func CimTombstoneFile(cimFSHandle FsHandle, path string) (hr error) { var _p0 *uint16 _p0, hr = syscall.UTF16PtrFromString(path) diff --git a/pkg/annotations/annotations.go b/pkg/annotations/annotations.go index 62eac7e80d..75aa1aedfe 100644 --- a/pkg/annotations/annotations.go +++ b/pkg/annotations/annotations.go @@ -221,12 +221,19 @@ const ( // should be encrypted or not. EncryptedScratchDisk = "io.microsoft.virtualmachine.storage.scratch.encrypted" - // SecurityPolicy is used to specify a security policy for opengcs to enforce. - SecurityPolicy = "io.microsoft.virtualmachine.lcow.securitypolicy" + // LCOWSecurityPolicy is used to specify a security policy for opengcs to enforce. + LCOWSecurityPolicy = "io.microsoft.virtualmachine.lcow.securitypolicy" - // SecurityPolicyEnforcer is used to specify which enforcer to initialize (open-door, standard or rego). + // LCOWSecurityPolicyEnforcer is used to specify which enforcer to initialize (open-door, standard or rego). // This allows for better fallback mechanics. - SecurityPolicyEnforcer = "io.microsoft.virtualmachine.lcow.enforcer" + LCOWSecurityPolicyEnforcer = "io.microsoft.virtualmachine.lcow.enforcer" + + // WCOW SecurityPolicy is used to specify a security policy for opengcs to enforce. + WCOWSecurityPolicy = "io.microsoft.virtualmachine.wcow.securitypolicy" + + // WCOW SecurityPolicyEnforcer is used to specify which enforcer to initialize (open-door, standard or rego). 
+ // This allows for better fallback mechanics. + WCOWSecurityPolicyEnforcer = "io.microsoft.virtualmachine.wcow.enforcer" // HclEnabled specifies whether to enable the host compatibility layer. HclEnabled = "io.microsoft.virtualmachine.lcow.hcl-enabled" @@ -290,6 +297,9 @@ const ( // UVMReferenceInfoFile specifies the filename of a signed UVM reference file to be passed to UVM. UVMReferenceInfoFile = "io.microsoft.virtualmachine.lcow.uvm-reference-info-file" + // UVMReferenceInfoFile specifies the filename of a signed UVM reference file to be passed to UVM. + WCOWUVMReferenceInfoFile = "io.microsoft.virtualmachine.wcow.uvm-reference-info-file" + // HostAMDCertificate specifies the filename of the AMD certificates to be passed to UVM. // The certificate is expected to be located in the same directory as the shim executable. HostAMDCertificate = "io.microsoft.virtualmachine.lcow.amd-certificate" diff --git a/pkg/cimfs/cim_test.go b/pkg/cimfs/cim_test.go index 7e194421c8..1ed0aa83db 100644 --- a/pkg/cimfs/cim_test.go +++ b/pkg/cimfs/cim_test.go @@ -5,6 +5,7 @@ package cimfs import ( "bytes" + "context" "errors" "fmt" "io" @@ -53,6 +54,14 @@ func (t *testBlockCIM) cimPath() string { return filepath.Join(t.BlockPath, t.CimName) } +type testVerifiedBlockCIM struct { + BlockCIM +} + +func (t *testVerifiedBlockCIM) cimPath() string { + return filepath.Join(t.BlockPath, t.CimName) +} + // A utility function to create a file/directory and write data to it in the given cim. 
func createCimFileUtil(c *CimFsWriter, fileTuple tuple) error { // create files inside the cim @@ -99,6 +108,8 @@ func openNewCIM(t *testing.T, newCIM testCIM) *CimFsWriter { writer, err = Create(val.imageDir, val.parentName, val.imageName) case *testBlockCIM: writer, err = CreateBlockCIM(val.BlockPath, val.CimName, val.Type) + case *testVerifiedBlockCIM: + writer, err = CreateBlockCIMWithOptions(context.Background(), &val.BlockCIM, WithDataIntegrity()) } if err != nil { t.Fatalf("failed while creating a cim: %s", err) @@ -666,3 +677,75 @@ func TestMergedLinksInMergedBlockCIMs(rootT *testing.T) { rootT.Logf("file contents don't match!") } } + +func TestVerifiedSingleFileBlockCIM(t *testing.T) { + if !IsVerifiedCimSupported() { + t.Skipf("verified CIMs are not supported") + } + + // contents to write to the CIM + testContents := []tuple{ + {"foo.txt", []byte("foo1"), false}, + {"bar.txt", []byte("bar"), false}, + } + + root := t.TempDir() + blockPath := filepath.Join(root, "layer.bcim") + tc := &testVerifiedBlockCIM{ + BlockCIM: BlockCIM{ + Type: BlockCIMTypeSingleFile, + BlockPath: blockPath, + CimName: "layer.cim", + }} + writer := openNewCIM(t, tc) + writeCIM(t, writer, testContents) + + mountvol := mountCIM(t, tc, CimMountVerifiedCim|CimMountSingleFileCim) + + compareContent(t, mountvol, testContents) +} + +func TestVerifiedSingleFileBlockCIMMount(t *testing.T) { + if !IsVerifiedCimSupported() { + t.Skipf("verified CIMs are not supported") + } + + // contents to write to the CIM + testContents := []tuple{ + {"foo.txt", []byte("foo1"), false}, + {"bar.txt", []byte("bar"), false}, + } + + root := t.TempDir() + blockPath := filepath.Join(root, "layer.bcim") + tc := &testVerifiedBlockCIM{ + BlockCIM: BlockCIM{ + Type: BlockCIMTypeSingleFile, + BlockPath: blockPath, + CimName: "layer.cim", + }} + writer := openNewCIM(t, tc) + writeCIM(t, writer, testContents) + + rootHash, err := GetVerificationInfo(blockPath) + if err != nil { + t.Fatalf("failed to get verification 
info: %s", err) + } + + // mount and read the contents of the cim + volumeGUID, err := guid.NewV4() + if err != nil { + t.Fatalf("generate cim mount GUID: %s", err) + } + + mountvol, err := MountVerifiedBlockCIM(&tc.BlockCIM, CimMountSingleFileCim, volumeGUID, rootHash) + if err != nil { + t.Fatalf("mount verified cim : %s", err) + } + t.Cleanup(func() { + if err := Unmount(mountvol); err != nil { + t.Logf("CIM unmount failed: %s", err) + } + }) + compareContent(t, mountvol, testContents) +} diff --git a/pkg/cimfs/cim_writer_windows.go b/pkg/cimfs/cim_writer_windows.go index 4204e87773..846718e62f 100644 --- a/pkg/cimfs/cim_writer_windows.go +++ b/pkg/cimfs/cim_writer_windows.go @@ -32,6 +32,8 @@ type CimFsWriter struct { activeStream winapi.StreamHandle // amount of bytes that can be written to the activeStream. activeLeft uint64 + // if true the CIM will be sealed after the writer is closed. + sealOnClose bool } // Create creates a new cim image. The CimFsWriter returned can then be used to do @@ -63,39 +65,108 @@ func Create(imagePath string, oldFSName string, newFSName string) (_ *CimFsWrite return &CimFsWriter{handle: handle, name: filepath.Join(imagePath, fsName)}, nil } -// Create creates a new block CIM and opens it for writing. The CimFsWriter -// returned can then be used to add/remove files to/from this CIM. -func CreateBlockCIM(blockPath, name string, blockType BlockCIMType) (_ *CimFsWriter, err error) { +// blockCIMConfig represents options for creating or merging block CIMs +type blockCIMConfig struct { + // ensures that the generated CIM is identical every time when created from the same source data. + // This is mostly required for image layers. Disabled by default. + consistentCIM bool + // enables data integrity checking, which means the CIM will be verified and sealed on close. + // This is useful for ensuring that the CIM is tamper-proof. Disabled by default.
+ dataIntegrity bool +} + +// BlockCIMOpt is a function type for configuring block CIM creation options +type BlockCIMOpt func(*blockCIMConfig) error + +// enabled consistent CIM creation, this ensures that CIMs created from identical source data will always be identical (i.e. SHA256 digest of the CIM will remain same) +func WithConsistentCIM() BlockCIMOpt { + return func(opts *blockCIMConfig) error { + opts.consistentCIM = true + return nil + } +} + +// WithDataIntegrity enables data integrity checking (verified CIM with sealing on close) +func WithDataIntegrity() BlockCIMOpt { + return func(opts *blockCIMConfig) error { + opts.dataIntegrity = true + return nil + } +} + +// CreateBlockCIMWithOptions creates a new block CIM with the specified options and opens it for writing. +// The CimFsWriter returned can then be used to add/remove files to/from this CIM. +func CreateBlockCIMWithOptions(ctx context.Context, bCIM *BlockCIM, options ...BlockCIMOpt) (_ *CimFsWriter, err error) { + // Apply default options + config := &blockCIMConfig{} + + // Apply provided options + for _, option := range options { + option(config) + } + + // Validate options + if bCIM.BlockPath == "" || bCIM.CimName == "" { + return nil, fmt.Errorf("both blockPath & name must be non empty: %w", os.ErrInvalid) + } + + if bCIM.Type == BlockCIMTypeNone { + return nil, fmt.Errorf("invalid block CIM type `%d`: %w", bCIM.Type, os.ErrInvalid) + } + + // Check OS support if !IsBlockCimSupported() { return nil, fmt.Errorf("block CIM not supported on this OS version") } - if blockPath == "" || name == "" { - return nil, fmt.Errorf("both blockPath & name must be non empty: %w", os.ErrInvalid) + + if config.dataIntegrity && !IsVerifiedCimSupported() { + return nil, fmt.Errorf("verified CIMs are not supported on this OS version") } - // When creating block CIMs we always want them to be consistent CIMs i.e a CIMs - // created from the same layer tar will always be identical. 
- var createFlags uint32 = CimCreateFlagConsistentCim - switch blockType { + // Build create flags based on options + var createFlags uint32 + if config.consistentCIM { + createFlags |= CimCreateFlagConsistentCim + } + if config.dataIntegrity { + createFlags |= CimCreateFlagVerifiedCim + } + + switch bCIM.Type { case BlockCIMTypeDevice: createFlags |= CimCreateFlagBlockDeviceCim case BlockCIMTypeSingleFile: createFlags |= CimCreateFlagSingleFileCim default: - return nil, fmt.Errorf("invalid block CIM type `%d`: %w", blockType, os.ErrInvalid) + return nil, fmt.Errorf("invalid block CIM type `%d`: %w", bCIM.Type, os.ErrInvalid) } var newNameUTF16 *uint16 - newNameUTF16, err = windows.UTF16PtrFromString(name) + newNameUTF16, err = windows.UTF16PtrFromString(bCIM.CimName) if err != nil { return nil, err } var handle winapi.FsHandle - if err := winapi.CimCreateImage2(blockPath, createFlags, nil, newNameUTF16, &handle); err != nil { - return nil, fmt.Errorf("failed to create block CIM at path %s,%s: %w", blockPath, name, err) + if err := winapi.CimCreateImage2(bCIM.BlockPath, createFlags, nil, newNameUTF16, &handle); err != nil { + return nil, fmt.Errorf("failed to create block CIM at path %s,%s: %w", bCIM.BlockPath, bCIM.CimName, err) } - return &CimFsWriter{handle: handle, name: name}, nil + + return &CimFsWriter{ + handle: handle, + name: filepath.Join(bCIM.BlockPath, bCIM.CimName), + sealOnClose: config.dataIntegrity, // Seal on close if data integrity is enabled + }, nil +} + +// Create creates a new block CIM and opens it for writing. The CimFsWriter +// returned can then be used to add/remove files to/from this CIM. +func CreateBlockCIM(blockPath, name string, blockType BlockCIMType) (_ *CimFsWriter, err error) { + return CreateBlockCIMWithOptions(context.Background(), &BlockCIM{ + Type: blockType, + BlockPath: blockPath, + CimName: name, + }, WithConsistentCIM()) } // CreateAlternateStream creates alternate stream of given size at the given path inside the cim. 
This will @@ -268,7 +339,15 @@ func (c *CimFsWriter) Close() (err error) { } err = winapi.CimCloseImage(c.handle) c.handle = 0 - return err + if err != nil { + return &OpError{Cim: c.name, Op: "close", Err: err} + } + if c.sealOnClose { + if err = sealBlockCIM(filepath.Dir(c.name)); err != nil { + return &OpError{Cim: c.name, Op: "seal", Err: err} + } + } + return nil } // DestroyCim finds out the region files, object files of this cim and then delete the @@ -351,13 +430,27 @@ func GetCimUsage(ctx context.Context, cimPath string) (uint64, error) { // considered the base CIM. (i.e file with the same path in CIM at index 0 will shadow // files with the same path at all other CIMs) When mounting this merged CIM the source // CIMs MUST be provided in the exact same order. -func MergeBlockCIMs(mergedCIM *BlockCIM, sourceCIMs []*BlockCIM) (err error) { +func MergeBlockCIMsWithOpts(ctx context.Context, mergedCIM *BlockCIM, sourceCIMs []*BlockCIM, opts ...BlockCIMOpt) (err error) { if !IsMergedCimSupported() { return fmt.Errorf("merged CIMs aren't supported on this OS version") } else if len(sourceCIMs) < 2 { return fmt.Errorf("need at least 2 source CIMs, got %d: %w", len(sourceCIMs), os.ErrInvalid) } + // Apply default options + config := &blockCIMConfig{} + + // Apply provided options + for _, opt := range opts { + opt(config) + } + + for _, sCIM := range sourceCIMs { + if sCIM.Type != mergedCIM.Type { + return fmt.Errorf("source CIM (%s) type MUST match with merged CIM type: %w", sCIM.String(), os.ErrInvalid) + } + } + var mergeFlag uint32 switch mergedCIM.Type { case BlockCIMTypeDevice: @@ -368,13 +461,7 @@ func MergeBlockCIMs(mergedCIM *BlockCIM, sourceCIMs []*BlockCIM) (err error) { return fmt.Errorf("invalid block CIM type `%d`: %w", mergedCIM.Type, os.ErrInvalid) } - for _, sCIM := range sourceCIMs { - if sCIM.Type != mergedCIM.Type { - return fmt.Errorf("source CIM (%s) type doesn't match with merged CIM type: %w", sCIM.String(), os.ErrInvalid) - } - } - - cim, 
err := CreateBlockCIM(mergedCIM.BlockPath, mergedCIM.CimName, mergedCIM.Type) + cim, err := CreateBlockCIMWithOptions(ctx, mergedCIM, opts...) if err != nil { return fmt.Errorf("create merged CIM: %w", err) } @@ -395,3 +482,45 @@ func MergeBlockCIMs(mergedCIM *BlockCIM, sourceCIMs []*BlockCIM) (err error) { } return nil } + +// MergeBlockCIMs creates a new merged BlockCIM from the provided source BlockCIMs. CIM +// at index 0 is considered to be topmost CIM and the CIM at index `length-1` is +// considered the base CIM. (i.e file with the same path in CIM at index 0 will shadow +// files with the same path at all other CIMs) When mounting this merged CIM the source +// CIMs MUST be provided in the exact same order. +func MergeBlockCIMs(mergedCIM *BlockCIM, sourceCIMs []*BlockCIM) (err error) { + return MergeBlockCIMsWithOpts(context.Background(), mergedCIM, sourceCIMs, WithConsistentCIM()) +} + +// sealBlockCIM seals a blockCIM at the given path so that no further modifications are allowed on it. This also writes a +// root hash in the block header so that in future any reads happening on the CIM can be easily verified against this root hash +// to detect tampering. +func sealBlockCIM(blockPath string) error { + var hashSize, fixedHeaderSize uint64 + hashBuf := make([]byte, cimHashSize) + if err := winapi.CimSealImage(blockPath, &hashSize, &fixedHeaderSize, &hashBuf[0]); err != nil { + return fmt.Errorf("failed to seal block CIM: %w", err) + } else if hashSize != cimHashSize { + return fmt.Errorf("unexpected cim hash size %d", hashSize) + } + return nil +} + +// GetVerificationInfo returns the root hash (digest) of a sealed CIM.
+func GetVerificationInfo(blockPath string) ([]byte, error) { + var ( + isSealed uint32 + hashSize uint64 + signatureSize uint64 + fixedHeaderSize uint64 + hash = make([]byte, cimHashSize) + ) + if err := winapi.CimGetVerificationInformation(blockPath, &isSealed, &hashSize, &signatureSize, &fixedHeaderSize, &hash[0], nil); err != nil { + return nil, fmt.Errorf("failed to get verification info from the CIM: %w", err) + } else if hashSize != cimHashSize { + return nil, fmt.Errorf("unexpected cim hash size %d", hashSize) + } else if isSealed == 0 { + return nil, fmt.Errorf("cim is not sealed") + } + return hash, nil +} diff --git a/pkg/cimfs/cimfs.go b/pkg/cimfs/cimfs.go index f301764387..57269607b0 100644 --- a/pkg/cimfs/cimfs.go +++ b/pkg/cimfs/cimfs.go @@ -31,6 +31,15 @@ func IsBlockCimSupported() bool { return build >= 27766 } +// IsVerifiedCimSupported returns true if block CIM format supports also writing verification information in the CIM. +func IsVerifiedCimSupported() bool { + build := osversion.Build() + // TODO(ambarve): Currently we are checking against a higher build number since there is no + // official build with block CIM support yet. Once we have that build, we should + // update the build number here. + return build >= 27800 +} + func IsMergedCimSupported() bool { // The merged CIM support was originally added before block CIM support. However, // some of the merged CIM features that we use (e.g. 
merged hard links) were added @@ -49,6 +58,7 @@ const ( CimMountFlagEnableDax uint32 = 0x2 CimMountBlockDeviceCim uint32 = 0x10 CimMountSingleFileCim uint32 = 0x20 + CimMountVerifiedCim uint32 = 0x80 CimCreateFlagNone uint32 = 0x0 CimCreateFlagDoNotExpandPEImages uint32 = 0x1 @@ -56,6 +66,7 @@ const ( CimCreateFlagBlockDeviceCim uint32 = 0x4 CimCreateFlagSingleFileCim uint32 = 0x8 CimCreateFlagConsistentCim uint32 = 0x10 + CimCreateFlagVerifiedCim uint32 = 0x40 CimMergeFlagNone uint32 = 0x0 CimMergeFlagSingleFile uint32 = 0x1 diff --git a/pkg/cimfs/common.go b/pkg/cimfs/common.go index 0a05f5a9d2..ab988aff06 100644 --- a/pkg/cimfs/common.go +++ b/pkg/cimfs/common.go @@ -15,6 +15,10 @@ import ( "github.com/Microsoft/hcsshim/pkg/cimfs/format" ) +const ( + cimHashSize = 32 // size of a hash of a verified CIM in bytes +) + var ( // Equivalent to SDDL of "D:NO_ACCESS_CONTROL". nullSd = []byte{1, 0, 4, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} diff --git a/pkg/cimfs/doc.go b/pkg/cimfs/doc.go index bb9ce57717..c40f38cc69 100644 --- a/pkg/cimfs/doc.go +++ b/pkg/cimfs/doc.go @@ -24,6 +24,16 @@ newFSName string) (_ *CimFsWriter, err error)` function defined in this package, block CIMs can be created with the `func CreateBlockCIM(blockPath, oldName, newName string, blockType BlockCIMType) (_ *CimFsWriter, err error)` function. +Verified CIMs: +A block CIM can also provide integrity checking (via a hash/Merkle tree, +similar to dm-verity on Linux). If a CIM is written and sealed, it generates a +root hash of all of its contents and shares it back with the client. Any +verified CIM can be mounted by passing a hash that we expect to be its root +hash. All read operations on such a mounted CIM will then validate that the +generated root hash matches with the one that was provided at mount time. If it +doesn't match, the read fails. This allows us to guarantee that the CIM based +layers aren't being modified underneath us.
+ Forking & Merging CIMs: In container world, CIMs are used for storing container image layers. Usually, one layer is stored in one CIM. This means we need a way to combine multiple CIMs to create the diff --git a/pkg/cimfs/mount_cim.go b/pkg/cimfs/mount_cim.go index 8588d63b34..905357eb6c 100644 --- a/pkg/cimfs/mount_cim.go +++ b/pkg/cimfs/mount_cim.go @@ -15,6 +15,10 @@ import ( "golang.org/x/sys/windows" ) +const ( + VolumePathFormat = "\\\\?\\Volume{%s}\\" +) + type MountError struct { Cim string Op string @@ -116,5 +120,34 @@ func MountMergedBlockCIMs(mergedCIM *BlockCIM, sourceCIMs []*BlockCIM, mountFlag if err := winapi.CimMergeMountImage(uint32(len(cimsToMerge)), &cimsToMerge[0], mountFlags, &volumeGUID); err != nil { return "", &MountError{Cim: filepath.Join(mergedCIM.BlockPath, mergedCIM.CimName), Op: "MountMerged", Err: err} } + return fmt.Sprintf(VolumePathFormat, volumeGUID.String()), nil +} + +// Mounts a verified block CIM with the provided root hash. The root hash is usually +// returned when the CIM is sealed or the root hash can be queried from a block CIM. +// Every read on the mounted volume will be verified to match against the provided root +// hash; if it doesn't, the read will fail. The CIM MUST have been created with the +// verified creation flag. +func MountVerifiedBlockCIM(bCIM *BlockCIM, mountFlags uint32, volumeGUID guid.GUID, rootHash []byte) (string, error) { + if len(rootHash) != cimHashSize { + return "", fmt.Errorf("unexpected root hash size %d, expected size is %d", len(rootHash), cimHashSize) + } + + // The CimMountVerifiedCim flag should only be used when using the regular mount + // CIM API. That flag is required to tell that API that this is a verified + // CIM. This API doesn't need that flag as it is already assumed that the CIM is + // verified.
+ switch bCIM.Type { + case BlockCIMTypeDevice: + mountFlags |= CimMountBlockDeviceCim + case BlockCIMTypeSingleFile: + mountFlags |= CimMountSingleFileCim + default: + return "", fmt.Errorf("invalid block CIM type `%d`: %w", bCIM.Type, os.ErrInvalid) + } + + if err := winapi.CimMountVerifiedImage(bCIM.BlockPath, bCIM.CimName, mountFlags, &volumeGUID, cimHashSize, &rootHash[0]); err != nil { + return "", &MountError{Cim: bCIM.String(), Op: "MountVerifiedCIM", Err: err} + } return fmt.Sprintf("\\\\?\\Volume{%s}\\", volumeGUID.String()), nil } diff --git a/pkg/securitypolicy/api.rego b/pkg/securitypolicy/api.rego index 82e79c9040..36a197ebc2 100644 --- a/pkg/securitypolicy/api.rego +++ b/pkg/securitypolicy/api.rego @@ -5,6 +5,7 @@ version := "@@API_VERSION@@" enforcement_points := { "mount_device": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}}, "mount_overlay": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}}, + "mount_cims": {"introducedVersion": "0.11.0", "default_results": {"allowed": false}}, "create_container": {"introducedVersion": "0.1.0", "default_results": {"allowed": false, "env_list": null, "allow_stdio_access": false}}, "unmount_device": {"introducedVersion": "0.2.0", "default_results": {"allowed": true}}, "unmount_overlay": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}}, diff --git a/pkg/securitypolicy/framework.rego b/pkg/securitypolicy/framework.rego index ffca157147..830d6015d0 100644 --- a/pkg/securitypolicy/framework.rego +++ b/pkg/securitypolicy/framework.rego @@ -57,6 +57,14 @@ layerPaths_ok(layers) { } } +layerHashes_ok(layers) { + length := count(layers) + count(input.layerHashes) == length + every i, hash in input.layerHashes { + layers[(length - i) - 1] == hash + } +} + default overlay_exists := false overlay_exists { @@ -95,6 +103,25 @@ candidate_containers := containers { containers := array.concat(policy_containers, fragment_containers) } +default mount_cims := {"allowed": 
false} + +mount_cims := {"metadata": [addMatches], "allowed": true} { + not overlay_exists + + containers := [container | + container := candidate_containers[_] + layerHashes_ok(container.layers) + ] + + count(containers) > 0 + addMatches := { + "name": "matches", + "action": "add", + "key": input.containerID, + "value": containers, + } +} + default mount_overlay := {"allowed": false} mount_overlay := {"metadata": [addMatches, addOverlayTarget], "allowed": true} { @@ -223,23 +250,37 @@ workingDirectory_ok(working_dir) { } privileged_ok(elevation_allowed) { + is_linux not input.privileged } privileged_ok(elevation_allowed) { + is_linux input.privileged input.privileged == elevation_allowed } +privileged_ok(no_new_privileges) { + # no-op for windows + is_windows +} + noNewPrivileges_ok(no_new_privileges) { + is_linux no_new_privileges input.noNewPrivileges } noNewPrivileges_ok(no_new_privileges) { + is_linux no_new_privileges == false } +noNewPrivileges_ok(no_new_privileges) { + # no-op for windows + is_windows +} + idName_ok(pattern, "any", value) { true } @@ -257,6 +298,7 @@ idName_ok(pattern, "re2", value) { } user_ok(user) { + is_linux user.umask == input.umask idName_ok(user.user_idname.pattern, user.user_idname.strategy, input.user) every group in input.groups { @@ -265,10 +307,20 @@ user_ok(user) { } } +user_ok(user) { + is_windows + input.user == user +} + seccomp_ok(seccomp_profile_sha256) { + is_linux input.seccompProfileSHA256 == seccomp_profile_sha256 } +seccomp_ok(seccomp_profile_sha256) { + is_windows +} + default container_started := false container_started { @@ -378,6 +430,7 @@ all_caps_sets_are_equal(sets) := caps { } valid_caps_for_all(containers, privileged) := caps { + is_linux allow_capability_dropping # find largest matching capabilities sets aka "the most specific" @@ -389,13 +442,21 @@ valid_caps_for_all(containers, privileged) := caps { } valid_caps_for_all(containers, privileged) := caps { + is_linux not allow_capability_dropping # no 
dropping allowed, so we just return the input caps := input.capabilities } +valid_caps_for_all(containers, privileged) := caps { + # no-op for windows + is_windows + caps := input.capabilities +} + caps_ok(allowed_caps, requested_caps) { + is_linux capsList_ok(allowed_caps.bounding, requested_caps.bounding) capsList_ok(allowed_caps.effective, requested_caps.effective) capsList_ok(allowed_caps.inheritable, requested_caps.inheritable) @@ -403,6 +464,10 @@ caps_ok(allowed_caps, requested_caps) { capsList_ok(allowed_caps.ambient, requested_caps.ambient) } +caps_ok(allowed_caps, requested_caps) { + is_windows +} + get_capabilities(container, privileged) := capabilities { container.capabilities != null capabilities := container.capabilities @@ -487,11 +552,10 @@ create_container := {"metadata": [updateMatches, addStarted], # check to see if the capabilities variables match, dropping # them if allowed (and necessary) - caps_list := valid_caps_for_all(possible_after_env_containers, input.privileged) - possible_after_caps_containers := [container | - container := possible_after_env_containers[_] - caps_ok(get_capabilities(container, input.privileged), caps_list) - ] + caps_result := possible_container_after_caps(possible_after_env_containers, input.privileged) + + possible_after_caps_containers := caps_result.containers + caps_list := caps_result.caps_list count(possible_after_caps_containers) > 0 @@ -523,6 +587,24 @@ create_container := {"metadata": [updateMatches, addStarted], }, } } +possible_container_after_caps(env_containers, privileged) := { + "containers": env_containers, + "caps_list": [] +} { + is_windows +} + +possible_container_after_caps(env_containers, privileged) := { + "containers": filtered, + "caps_list": caps_list +} { + is_linux + caps_list := valid_caps_for_all(env_containers, privileged) + filtered := [container | + container := env_containers[_] + caps_ok(get_capabilities(container, privileged), caps_list) + ] +} mountSource_ok(constraint, source) { 
startswith(constraint, data.sandboxPrefix) @@ -585,10 +667,23 @@ mount_ok(mounts, allow_elevated, mount) { } mountList_ok(mounts, allow_elevated) { + is_linux every mount in input.mounts { mount_ok(mounts, allow_elevated, mount) } } +mountList_ok(mounts, allow_elevated) { + # no-op for windows + is_windows +} + +is_linux { + data.metadata.operatingsystem[ostype] == "linux" +} + +is_windows { + data.metadata.operatingsystem[ostype] == "windows" +} default exec_in_container := {"allowed": false} @@ -625,11 +720,10 @@ exec_in_container := {"metadata": [updateMatches], # check to see if the capabilities variables match, dropping # them if allowed (and necessary) - caps_list := valid_caps_for_all(possible_after_env_containers, container_privileged) - possible_after_caps_containers := [container | - container := possible_after_env_containers[_] - caps_ok(get_capabilities(container, container_privileged), caps_list) - ] + caps_result := possible_container_after_caps(possible_after_env_containers, container_privileged) + + possible_after_caps_containers := caps_result.containers + caps_list := caps_result.caps_list count(possible_after_caps_containers) > 0 @@ -1085,6 +1179,7 @@ privileged_matches { } errors["privileged escalation not allowed"] { + is_linux input.rule in ["create_container"] not privileged_matches } @@ -1262,6 +1357,7 @@ mount_matches(mount) { } errors[mountError] { + is_linux input.rule == "create_container" bad_mounts := [mount.destination | mount := input.mounts[_] @@ -1410,6 +1506,7 @@ errors[fragment_framework_version_error] { } errors["containers only distinguishable by allow_stdio_access"] { + is_linux input.rule == "create_container" not container_started @@ -1499,6 +1596,7 @@ noNewPrivileges_matches { } errors["invalid noNewPrivileges"] { + is_linux input.rule in ["create_container", "exec_in_container"] not noNewPrivileges_matches } @@ -1521,11 +1619,13 @@ user_matches { } errors["invalid user"] { + is_linux input.rule in ["create_container", 
"exec_in_container"] not user_matches } errors["capabilities don't match"] { + is_linux input.rule == "create_container" not container_started @@ -1565,6 +1665,7 @@ errors["capabilities don't match"] { } errors["capabilities don't match"] { + is_linux input.rule == "exec_in_container" container_started @@ -1604,6 +1705,7 @@ errors["capabilities don't match"] { # covers exec_in_container as well. it shouldn't be possible to ever get # an exec_in_container as it "inherits" capabilities rules from create_container errors["containers only distinguishable by capabilties"] { + is_linux input.rule == "create_container" allow_capability_dropping @@ -1649,6 +1751,7 @@ seccomp_matches { } errors["invalid seccomp"] { + is_linux input.rule == "create_container" not seccomp_matches } diff --git a/pkg/securitypolicy/open_door.rego b/pkg/securitypolicy/open_door.rego index 2bc36123d8..a8e283092d 100644 --- a/pkg/securitypolicy/open_door.rego +++ b/pkg/securitypolicy/open_door.rego @@ -5,6 +5,7 @@ api_version := "@@API_VERSION@@" mount_device := {"allowed": true} mount_overlay := {"allowed": true} create_container := {"allowed": true, "env_list": null, "allow_stdio_access": true} +mount_cims := {"allowed": true} unmount_device := {"allowed": true} unmount_overlay := {"allowed": true} exec_in_container := {"allowed": true, "env_list": null} diff --git a/pkg/securitypolicy/policy.rego b/pkg/securitypolicy/policy.rego index 139d340048..9414116c19 100644 --- a/pkg/securitypolicy/policy.rego +++ b/pkg/securitypolicy/policy.rego @@ -9,6 +9,7 @@ mount_device := data.framework.mount_device unmount_device := data.framework.unmount_device mount_overlay := data.framework.mount_overlay unmount_overlay := data.framework.unmount_overlay +mount_cims:= data.framework.mount_cims create_container := data.framework.create_container exec_in_container := data.framework.exec_in_container exec_external := data.framework.exec_external diff --git a/pkg/securitypolicy/regopolicy_test.go 
b/pkg/securitypolicy/regopolicy_test.go index fa6a9560ca..da75863595 100644 --- a/pkg/securitypolicy/regopolicy_test.go +++ b/pkg/securitypolicy/regopolicy_test.go @@ -37,6 +37,7 @@ const ( maxGeneratedFragmentIssuerLength = 16 maxPlan9MountTargetLength = 64 maxPlan9MountIndex = 16 + osType = "linux" ) func Test_RegoTemplates(t *testing.T) { @@ -146,7 +147,8 @@ func Test_MarshalRego_Policy(t *testing.T) { return false } - _, err = newRegoPolicy(expected, defaultMounts, privilegedMounts) + _, err = newRegoPolicy(expected, defaultMounts, privilegedMounts, osType) + if err != nil { t.Errorf("unable to convert policy to rego: %v", err) return false @@ -232,7 +234,8 @@ func Test_MarshalRego_Fragment(t *testing.T) { func Test_Rego_EnforceDeviceMountPolicy_No_Matches(t *testing.T) { f := func(p *generatedConstraints) bool { securityPolicy := p.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Errorf("unable to convert policy to rego: %v", err) return false @@ -257,7 +260,8 @@ func Test_Rego_EnforceDeviceMountPolicy_No_Matches(t *testing.T) { func Test_Rego_EnforceDeviceMountPolicy_Matches(t *testing.T) { f := func(p *generatedConstraints) bool { securityPolicy := p.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Errorf("unable to convert policy to rego: %v", err) return false @@ -280,7 +284,8 @@ func Test_Rego_EnforceDeviceMountPolicy_Matches(t *testing.T) { func Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries(t *testing.T) { f := func(p *generatedConstraints) bool { securityPolicy := p.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := 
newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Error(err) return false @@ -318,7 +323,8 @@ func Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries(t *testing.T) { func Test_Rego_EnforceDeviceMountPolicy_Duplicate_Device_Target(t *testing.T) { f := func(p *generatedConstraints) bool { securityPolicy := p.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Errorf("unable to convert policy to rego: %v", err) return false @@ -413,7 +419,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_With_Same_Root_Hash(t *testing.T constraints.containers = []*securityPolicyContainer{container} constraints.externalProcesses = generateExternalProcesses(testRand) securityPolicy := constraints.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatal("Unable to create security policy") } @@ -449,7 +456,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_Shared_Layers(t *testing.T) { constraints.externalProcesses = generateExternalProcesses(testRand) securityPolicy := constraints.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatal("Unable to create security policy") } @@ -559,7 +567,7 @@ func Test_Rego_EnforceOverlayMountPolicy_Reusing_ID_Across_Overlays(t *testing.T policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), osType) if err != nil { t.Fatal(err) } @@ -611,7 +619,8 @@ func 
Test_Rego_EnforceOverlayMountPolicy_Multiple_Instances_Same_Container(t *te } securityPolicy := constraints.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("failed create enforcer") } @@ -899,7 +908,7 @@ func Test_Rego_EnforceCreateContainer_Start_All_Containers(t *testing.T) { policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), osType) if err != nil { t.Error(err) return false @@ -1779,7 +1788,8 @@ func Test_Rego_MountPolicy_MountPrivilegedWhenNotAllowed(t *testing.T) { func Test_Rego_Version_Unregistered_Enforcement_Point(t *testing.T) { gc := generateConstraints(testRand, maxContainersInGeneratedConstraints) securityPolicy := gc.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create a new Rego policy: %v", err) } @@ -1800,7 +1810,8 @@ func Test_Rego_Version_Unregistered_Enforcement_Point(t *testing.T) { func Test_Rego_Version_Future_Enforcement_Point(t *testing.T) { gc := generateConstraints(testRand, maxContainersInGeneratedConstraints) securityPolicy := gc.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create a new Rego policy: %v", err) } @@ -1829,7 +1840,8 @@ func Test_Rego_Version_Future_Enforcement_Point(t *testing.T) { // by their respective version information. 
func Test_Rego_Version_Unavailable_Enforcement_Point(t *testing.T) { code := "package policy\n\napi_version := \"0.0.1\"" - policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create a new Rego policy: %v", err) } @@ -1862,7 +1874,8 @@ func Test_Rego_Version_Unavailable_Enforcement_Point(t *testing.T) { func Test_Rego_Enforcement_Point_Allowed(t *testing.T) { code := "package policy\n\napi_version := \"0.0.1\"" - policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create a new Rego policy: %v", err) } @@ -1913,7 +1926,8 @@ api_version := "0.0.1" __fixture_for_allowed_extra__ := {"allowed": true} ` - policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create a new Rego policy: %v", err) } @@ -1950,7 +1964,8 @@ __fixture_for_allowed_extra__ := {"allowed": true} func Test_Rego_No_API_Version(t *testing.T) { code := "package policy" - policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create a new Rego policy: %v", err) } @@ -2433,7 +2448,8 @@ exec_external := { strings.Join(generateEnvs(envSet), `","`), strings.Join(generateEnvs(envSet), `","`)) - policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Errorf("error creating policy: %v", err) return false @@ -2490,7 +2506,8 @@ func Test_Rego_InvalidEnvList(t *testing.T) { "env_list": true }`, apiVersion, frameworkVersion) - policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}) + policy, err := 
newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("error creating policy: %v", err) } @@ -2539,7 +2556,8 @@ func Test_Rego_InvalidEnvList_Member(t *testing.T) { "env_list": ["one", ["two"], "three"] }`, apiVersion, frameworkVersion) - policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("error creating policy: %v", err) } @@ -2796,7 +2814,8 @@ func Test_Rego_ExecExternalProcessPolicy_DropEnvs_Multiple(t *testing.T) { policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { t.Fatal(err) } @@ -2840,7 +2859,8 @@ func Test_Rego_ExecExternalProcessPolicy_DropEnvs_Multiple_NoMatch(t *testing.T) policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { t.Fatal(err) } @@ -3796,7 +3816,7 @@ func Test_Rego_LoadFragment_SemverVersion(t *testing.T) { defaultMounts := toOCIMounts(generateMounts(testRand)) privilegedMounts := toOCIMounts(generateMounts(testRand)) - policy, err := newRegoPolicy(securityPolicy.marshalRego(), defaultMounts, privilegedMounts) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), defaultMounts, privilegedMounts, osType) if err != nil { t.Fatalf("error compiling policy: %v", err) @@ -4159,7 +4179,8 @@ load_fragment := {"allowed": true, "add_module": true} { mount_device := data.fragment.mount_device `, apiVersion, frameworkVersion, issuer, feed) - policy, err := newRegoPolicy(policyCode, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(policyCode, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create Rego policy: %v", err) } @@ -4441,7 +4462,8 @@ func Test_Rego_ExecExternal_StdioAccess_NotAllowed(t *testing.T) { 
gc.externalProcesses = append(gc.externalProcesses, gc.externalProcesses[0].clone()) gc.externalProcesses[0].allowStdioAccess = !gc.externalProcesses[0].allowStdioAccess - policy, err := newRegoPolicy(gc.toPolicy().marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(gc.toPolicy().marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("error marshaling policy: %v", err) } @@ -4851,7 +4873,8 @@ func Test_Rego_MissingEnvList(t *testing.T) { exec_external := {"allowed": true} `, apiVersion) - policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("error compiling the rego policy: %v", err) } @@ -5004,7 +5027,8 @@ func Test_Rego_ExecExternalProcessPolicy_ConflictingAllowStdioAccessHasErrorMess policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { t.Fatal(err) } @@ -5121,7 +5145,8 @@ func Test_Rego_ExecExternalProcessPolicy_RequiredEnvMissingHasErrorMessage(t *te policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { t.Fatal(err) } @@ -5597,7 +5622,8 @@ func Test_Rego_FrameworkSVN(t *testing.T) { policy, err := newRegoPolicy(code, toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { t.Fatalf("unable to create policy: %v", err) } @@ -5627,7 +5653,7 @@ func Test_Rego_Fragment_FrameworkSVN(t *testing.T) { defaultMounts := toOCIMounts(generateMounts(testRand)) privilegedMounts := toOCIMounts(generateMounts(testRand)) - policy, err := newRegoPolicy(securityPolicy.marshalRego(), defaultMounts, privilegedMounts) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), defaultMounts, privilegedMounts, 
osType) if err != nil { t.Fatalf("error compiling policy: %v", err) @@ -5675,7 +5701,8 @@ func Test_Rego_APISVN(t *testing.T) { policy, err := newRegoPolicy(code, toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { t.Fatalf("unable to create policy: %v", err) } @@ -5702,7 +5729,8 @@ func Test_Rego_NoReason(t *testing.T) { mount_device := {"allowed": false} ` - policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create policy: %v", err) } @@ -5795,7 +5823,8 @@ func Test_Rego_ErrorTruncation_CustomPolicy(t *testing.T) { reason := {"custom_error": "%s"} `, randString(testRand, 2048)) - policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create policy: %v", err) } @@ -5823,7 +5852,8 @@ func Test_Rego_Missing_Enforcement_Point(t *testing.T) { reason := {"errors": data.framework.errors} ` - policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(code, []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { t.Fatalf("unable to create policy: %v", err) } @@ -6101,7 +6131,8 @@ type regoOverlayTestConfig struct { func setupRegoOverlayTest(gc *generatedConstraints, valid bool) (tc *regoOverlayTestConfig, err error) { securityPolicy := gc.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { return nil, err } @@ -6164,7 +6195,8 @@ func setupRegoCreateContainerTest(gc *generatedConstraints, testContainer *secur policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + 
osType) if err != nil { return nil, err } @@ -6230,7 +6262,8 @@ func setupRegoRunningContainerTest(gc *generatedConstraints, privileged bool) (t policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { return nil, err } @@ -6331,7 +6364,8 @@ func setupExternalProcessTest(gc *generatedConstraints) (tc *regoExternalPolicyT policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { return nil, err } @@ -6358,7 +6392,8 @@ func setupPlan9MountTest(gc *generatedConstraints) (tc *regoPlan9MountTestConfig policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { return nil, err } @@ -6439,7 +6474,8 @@ func setupGetPropertiesTest(gc *generatedConstraints, allowPropertiesAccess bool policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { return nil, err } @@ -6462,7 +6498,8 @@ func setupDumpStacksTest(constraints *generatedConstraints, allowDumpStacks bool policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { return nil, err } @@ -6500,7 +6537,8 @@ type regoPolicyOnlyTestConfig struct { func setupRegoPolicyOnlyTest(gc *generatedConstraints) (tc *regoPolicyOnlyTestConfig, err error) { securityPolicy := gc.toPolicy() - policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { return nil, err } @@ -6664,7 +6702,7 @@ func 
setupRegoFragmentTestConfig(gc *generatedConstraints, numFragments int, inc securityPolicy := gc.toPolicy() defaultMounts := toOCIMounts(generateMounts(testRand)) privilegedMounts := toOCIMounts(generateMounts(testRand)) - policy, err := newRegoPolicy(securityPolicy.marshalRego(), defaultMounts, privilegedMounts) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), defaultMounts, privilegedMounts, osType) if err != nil { return nil, err @@ -6748,7 +6786,8 @@ func setupRegoDropEnvsTest(disjoint bool) (*regoContainerTestConfig, error) { policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), - toOCIMounts(privilegedMounts)) + toOCIMounts(privilegedMounts), + osType) if err != nil { return nil, err @@ -6842,7 +6881,8 @@ func setupFrameworkVersionTest(gc *generatedConstraints, policyVersion string, v } securityPolicy := gc.toPolicy() - policy, err := newRegoPolicy(setFrameworkVersion(securityPolicy.marshalRego(), policyVersion), []oci.Mount{}, []oci.Mount{}) + policy, err := newRegoPolicy(setFrameworkVersion(securityPolicy.marshalRego(), policyVersion), []oci.Mount{}, []oci.Mount{}, osType) + if err != nil { return nil, err } @@ -7286,7 +7326,8 @@ func setupRegoScratchMountTest( defaultMounts := generateMounts(testRand) privilegedMounts := generateMounts(testRand) - policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), toOCIMounts(privilegedMounts)) + policy, err := newRegoPolicy(securityPolicy.marshalRego(), toOCIMounts(defaultMounts), toOCIMounts(privilegedMounts), osType) + if err != nil { return nil, err } diff --git a/pkg/securitypolicy/securitypolicy_linux.go b/pkg/securitypolicy/securitypolicy_linux.go new file mode 100644 index 0000000000..eb4a01848c --- /dev/null +++ b/pkg/securitypolicy/securitypolicy_linux.go @@ -0,0 +1,144 @@ +//go:build linux +// +build linux + +package securitypolicy + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + specGuest 
"github.com/Microsoft/hcsshim/internal/guest/spec" + specInternal "github.com/Microsoft/hcsshim/internal/guest/spec" + "github.com/Microsoft/hcsshim/internal/guestpath" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/moby/sys/user" + oci "github.com/opencontainers/runtime-spec/specs-go" + "github.com/pkg/errors" +) + +// This is being used by StandEnforcer. +// substituteUVMPath substitutes mount prefix to an appropriate path inside +// UVM. At policy generation time, it's impossible to tell what the sandboxID +// will be, so the prefix substitution needs to happen during runtime. +func substituteUVMPath(sandboxID string, m mountInternal) mountInternal { + if strings.HasPrefix(m.Source, guestpath.SandboxMountPrefix) { + m.Source = specInternal.SandboxMountSource(sandboxID, m.Source) + } else if strings.HasPrefix(m.Source, guestpath.HugePagesMountPrefix) { + m.Source = specInternal.HugePagesMountSource(sandboxID, m.Source) + } + return m +} + +// SandboxMountsDir returns sandbox mounts directory inside UVM/host. +func SandboxMountsDir(sandboxID string) string { + return specInternal.SandboxMountsDir((sandboxID)) +} + +// HugePagesMountsDir returns hugepages mounts directory inside UVM. 
+func HugePagesMountsDir(sandboxID string) string { + return specInternal.HugePagesMountsDir(sandboxID) +} + +func getUser(passwdPath string, filter func(user.User) bool) (user.User, error) { + users, err := user.ParsePasswdFileFilter(passwdPath, filter) + if err != nil { + return user.User{}, err + } + if len(users) != 1 { + return user.User{}, errors.Errorf("expected exactly 1 user matched '%d'", len(users)) + } + return users[0], nil +} + +func getGroup(groupPath string, filter func(user.Group) bool) (user.Group, error) { + groups, err := user.ParseGroupFileFilter(groupPath, filter) + if err != nil { + return user.Group{}, err + } + if len(groups) != 1 { + return user.Group{}, errors.Errorf("expected exactly 1 group matched '%d'", len(groups)) + } + return groups[0], nil +} + +func GetAllUserInfo(containerID string, process *oci.Process) (IDName, []IDName, string, error) { + rootPath := filepath.Join(guestpath.LCOWRootPrefixInUVM, containerID, guestpath.RootfsPath) + passwdPath := filepath.Join(rootPath, "/etc/passwd") + groupPath := filepath.Join(rootPath, "/etc/group") + + if process == nil { + return IDName{}, nil, "", errors.New("spec.Process is nil") + } + + // this default value is used in the Linux kernel if no umask is specified + umask := "0022" + if process.User.Umask != nil { + umask = fmt.Sprintf("%04o", *process.User.Umask) + } + + if process.User.Username != "" { + uid, gid, err := specGuest.ParseUserStr(rootPath, process.User.Username) + if err == nil { + userIDName := IDName{ID: strconv.FormatUint(uint64(uid), 10)} + groupIDName := IDName{ID: strconv.FormatUint(uint64(gid), 10)} + return userIDName, []IDName{groupIDName}, umask, nil + } + log.G(context.Background()).WithError(err).Warn("failed to parse user str, fallback to lookup") + } + + // fallback UID/GID lookup + uid := process.User.UID + userIDName := IDName{ID: strconv.FormatUint(uint64(uid), 10), Name: ""} + if _, err := os.Stat(passwdPath); err == nil { + userInfo, err := 
getUser(passwdPath, func(user user.User) bool { + return uint32(user.Uid) == uid + }) + + if err != nil { + return userIDName, nil, "", err + } + + userIDName.Name = userInfo.Name + } + + gid := process.User.GID + groupIDName := IDName{ID: strconv.FormatUint(uint64(gid), 10), Name: ""} + + checkGroup := true + if _, err := os.Stat(groupPath); err == nil { + groupInfo, err := getGroup(groupPath, func(group user.Group) bool { + return uint32(group.Gid) == gid + }) + + if err != nil { + return userIDName, nil, "", err + } + groupIDName.Name = groupInfo.Name + } else { + checkGroup = false + } + + groupIDNames := []IDName{groupIDName} + additionalGIDs := process.User.AdditionalGids + if len(additionalGIDs) > 0 { + for _, gid := range additionalGIDs { + groupIDName = IDName{ID: strconv.FormatUint(uint64(gid), 10), Name: ""} + if checkGroup { + groupInfo, err := getGroup(groupPath, func(group user.Group) bool { + return uint32(group.Gid) == gid + }) + if err != nil { + return userIDName, nil, "", err + } + groupIDName.Name = groupInfo.Name + } + groupIDNames = append(groupIDNames, groupIDName) + } + } + + return userIDName, groupIDNames, umask, nil +} diff --git a/pkg/securitypolicy/securitypolicy_test.go b/pkg/securitypolicy/securitypolicy_test.go index cf5046b878..378629d706 100644 --- a/pkg/securitypolicy/securitypolicy_test.go +++ b/pkg/securitypolicy/securitypolicy_test.go @@ -1,6 +1,3 @@ -//go:build linux -// +build linux - package securitypolicy import ( diff --git a/pkg/securitypolicy/securitypolicy_windows.go b/pkg/securitypolicy/securitypolicy_windows.go new file mode 100644 index 0000000000..98b4554342 --- /dev/null +++ b/pkg/securitypolicy/securitypolicy_windows.go @@ -0,0 +1,30 @@ +//go:build windows +// +build windows + +package securitypolicy + +import oci "github.com/opencontainers/runtime-spec/specs-go" + +// This is being used by StandEnforcer and is a no-op for windows. 
+// substituteUVMPath substitutes mount prefix to an appropriate path inside +// UVM. At policy generation time, it's impossible to tell what the sandboxID +// will be, so the prefix substitution needs to happen during runtime. +func substituteUVMPath(sandboxID string, m mountInternal) mountInternal { + //no-op for windows + _ = sandboxID + return m +} + +// SandboxMountsDir returns sandbox mounts directory inside UVM/host. +func SandboxMountsDir(sandboxID string) string { + return "" +} + +// HugePagesMountsDir returns hugepages mounts directory inside UVM. +func HugePagesMountsDir(sandboxID string) string { + return "" +} + +func GetAllUserInfo(containerID string, process *oci.Process) (IDName, []IDName, string, error) { + return IDName{}, []IDName{}, "", nil +} diff --git a/pkg/securitypolicy/securitypolicyenforcer.go b/pkg/securitypolicy/securitypolicyenforcer.go index 5806eb7544..d3f89a4577 100644 --- a/pkg/securitypolicy/securitypolicyenforcer.go +++ b/pkg/securitypolicy/securitypolicyenforcer.go @@ -1,6 +1,3 @@ -//go:build linux -// +build linux - package securitypolicy import ( @@ -10,21 +7,46 @@ import ( "fmt" "regexp" "strconv" - "strings" "sync" "syscall" + "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" oci "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" - - specGuest "github.com/Microsoft/hcsshim/internal/guest/spec" - "github.com/Microsoft/hcsshim/internal/guestpath" ) -type createEnforcerFunc func(base64EncodedPolicy string, criMounts, criPrivilegedMounts []oci.Mount, maxErrorMessageLength int) (SecurityPolicyEnforcer, error) +type createEnforcerFunc func(base64EncodedPolicy string, criMounts, criPrivilegedMounts []oci.Mount, maxErrorMessageLength int, osType string) (SecurityPolicyEnforcer, error) type EnvList []string +type ExecOptions struct { + User *IDName // for linux, optional: nil means "not set". 
for windows, only name is set + Groups []IDName // optional: empty slice or nil + Umask string // optional: "" means unspecified + Capabilities *oci.LinuxCapabilities // optional: nil means "none" + NoNewPrivileges *bool // optional: nil means "not set" +} + +type CreateContainerOptions struct { + SandboxID string + Privileged *bool + NoNewPrivileges *bool + Groups []IDName + Umask string + Capabilities *oci.LinuxCapabilities + SeccompProfileSHA256 string +} + +type SignalContainerOptions struct { + IsInitProcess bool + // One of these will be set depending on platform + LinuxSignal syscall.Signal + WindowsSignal guestrequest.SignalValueWCOW + + LinuxStartupArgs []string + WindowsCommand string +} + const ( openDoorEnforcer = "open_door" standardEnforcer = "standard" @@ -61,6 +83,16 @@ type SecurityPolicyEnforcer interface { capabilities *oci.LinuxCapabilities, seccompProfileSHA256 string, ) (EnvList, *oci.LinuxCapabilities, bool, error) + EnforceCreateContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + mounts []oci.Mount, + user IDName, + opts *CreateContainerOptions, + ) (EnvList, *oci.LinuxCapabilities, bool, error) ExtendDefaultMounts([]oci.Mount) error EncodedSecurityPolicy() string EnforceExecInContainerPolicy( @@ -75,9 +107,18 @@ type SecurityPolicyEnforcer interface { umask string, capabilities *oci.LinuxCapabilities, ) (EnvList, *oci.LinuxCapabilities, bool, error) + EnforceExecInContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + opts *ExecOptions, + ) (EnvList, *oci.LinuxCapabilities, bool, error) EnforceExecExternalProcessPolicy(ctx context.Context, argList []string, envList []string, workingDir string) (EnvList, bool, error) EnforceShutdownContainerPolicy(ctx context.Context, containerID string) error EnforceSignalContainerProcessPolicy(ctx context.Context, containerID string, signal syscall.Signal, 
isInitProcess bool, startupArgList []string) error + EnforceSignalContainerProcessPolicyV2(ctx context.Context, containerID string, opts *SignalContainerOptions) error EnforcePlan9MountPolicy(ctx context.Context, target string) (err error) EnforcePlan9UnmountPolicy(ctx context.Context, target string) (err error) EnforceGetPropertiesPolicy(ctx context.Context) error @@ -86,15 +127,19 @@ type SecurityPolicyEnforcer interface { LoadFragment(ctx context.Context, issuer string, feed string, code string) error EnforceScratchMountPolicy(ctx context.Context, scratchPath string, encrypted bool) (err error) EnforceScratchUnmountPolicy(ctx context.Context, scratchPath string) (err error) + EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) (err error) GetUserInfo(containerID string, spec *oci.Process) (IDName, []IDName, string, error) } +//nolint type stringSet map[string]struct{} +//nolint func (s stringSet) add(item string) { s[item] = struct{}{} } +//nolint func (s stringSet) contains(item string) bool { _, contains := s[item] return contains @@ -122,7 +167,7 @@ func newSecurityPolicyFromBase64JSON(base64EncodedPolicy string) (*SecurityPolic // createAllowAllEnforcer creates and returns OpenDoorSecurityPolicyEnforcer instance. // Both AllowAll and Containers cannot be set at the same time. -func createOpenDoorEnforcer(base64EncodedPolicy string, _, _ []oci.Mount, _ int) (SecurityPolicyEnforcer, error) { +func createOpenDoorEnforcer(base64EncodedPolicy string, _, _ []oci.Mount, _ int, _ string) (SecurityPolicyEnforcer, error) { // This covers the case when an "open_door" enforcer was requested, but no // actual security policy was passed. This can happen e.g. when a container // scratch is created for the first time. 
@@ -172,6 +217,7 @@ func createStandardEnforcer( criMounts, criPrivilegedMounts []oci.Mount, maxErrorMessageLength int, + osType string, ) (SecurityPolicyEnforcer, error) { securityPolicy, err := newSecurityPolicyFromBase64JSON(base64EncodedPolicy) if err != nil { @@ -179,7 +225,7 @@ func createStandardEnforcer( } if securityPolicy.AllowAll { - return createOpenDoorEnforcer(base64EncodedPolicy, criMounts, criPrivilegedMounts, maxErrorMessageLength) + return createOpenDoorEnforcer(base64EncodedPolicy, criMounts, criPrivilegedMounts, maxErrorMessageLength, osType) } containers, err := securityPolicy.Containers.toInternal() @@ -209,6 +255,7 @@ func CreateSecurityPolicyEnforcer( criMounts, criPrivilegedMounts []oci.Mount, maxErrorMessageLength int, + osType string, ) (SecurityPolicyEnforcer, error) { if enforcer == "" { enforcer = defaultEnforcer @@ -219,7 +266,7 @@ func CreateSecurityPolicyEnforcer( if createEnforcer, ok := registeredEnforcers[enforcer]; !ok { return nil, fmt.Errorf("unknown enforcer: %q", enforcer) } else { - return createEnforcer(base64EncodedPolicy, criMounts, criPrivilegedMounts, maxErrorMessageLength) + return createEnforcer(base64EncodedPolicy, criMounts, criPrivilegedMounts, maxErrorMessageLength, osType) } } @@ -512,12 +559,36 @@ func (pe *StandardSecurityPolicyEnforcer) EnforceCreateContainerPolicy( return envList, caps, true, nil } +func (*StandardSecurityPolicyEnforcer) EnforceCreateContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + mounts []oci.Mount, + user IDName, + opts *CreateContainerOptions, +) (EnvList, *oci.LinuxCapabilities, bool, error) { + return envList, opts.Capabilities, true, nil +} + // Stub. We are deprecating the standard enforcer. Newly added enforcement // points are simply allowed. 
func (*StandardSecurityPolicyEnforcer) EnforceExecInContainerPolicy(_ context.Context, _ string, _ []string, envList []string, _ string, _ bool, _ IDName, _ []IDName, _ string, caps *oci.LinuxCapabilities) (EnvList, *oci.LinuxCapabilities, bool, error) { return envList, caps, true, nil } +func (*StandardSecurityPolicyEnforcer) EnforceExecInContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + opts *ExecOptions, +) (EnvList, *oci.LinuxCapabilities, bool, error) { + return envList, opts.Capabilities, true, nil +} + // Stub. We are deprecating the standard enforcer. Newly added enforcement // points are simply allowed. func (*StandardSecurityPolicyEnforcer) EnforceExecExternalProcessPolicy(_ context.Context, _ []string, envList []string, _ string) (EnvList, bool, error) { @@ -536,6 +607,10 @@ func (*StandardSecurityPolicyEnforcer) EnforceSignalContainerProcessPolicy(conte return nil } +func (*StandardSecurityPolicyEnforcer) EnforceSignalContainerProcessPolicyV2(ctx context.Context, containerID string, opts *SignalContainerOptions) error { + return nil +} + // Stub. We are deprecating the standard enforcer. Newly added enforcement // points are simply allowed. func (*StandardSecurityPolicyEnforcer) EnforcePlan9MountPolicy(context.Context, string) error { @@ -590,6 +665,10 @@ func (StandardSecurityPolicyEnforcer) EnforceScratchUnmountPolicy(context.Contex return nil } +func (StandardSecurityPolicyEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) error { + return nil +} + // Stub. We are deprecating the standard enforcer. 
func (StandardSecurityPolicyEnforcer) GetUserInfo(containerID string, spec *oci.Process) (IDName, []IDName, string, error) { return IDName{}, nil, "", nil @@ -856,18 +935,6 @@ func (c *securityPolicyContainer) matchMount(sandboxID string, m oci.Mount) (err return fmt.Errorf("mount is not allowed by policy: %+v", m) } -// substituteUVMPath substitutes mount prefix to an appropriate path inside -// UVM. At policy generation time, it's impossible to tell what the sandboxID -// will be, so the prefix substitution needs to happen during runtime. -func substituteUVMPath(sandboxID string, m mountInternal) mountInternal { - if strings.HasPrefix(m.Source, guestpath.SandboxMountPrefix) { - m.Source = specGuest.SandboxMountSource(sandboxID, m.Source) - } else if strings.HasPrefix(m.Source, guestpath.HugePagesMountPrefix) { - m.Source = specGuest.HugePagesMountSource(sandboxID, m.Source) - } - return m -} - func stringSlicesEqual(slice1, slice2 []string) bool { if len(slice1) != len(slice2) { return false @@ -911,10 +978,34 @@ func (OpenDoorSecurityPolicyEnforcer) EnforceCreateContainerPolicy(_ context.Con return envList, caps, true, nil } +func (OpenDoorSecurityPolicyEnforcer) EnforceCreateContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + mounts []oci.Mount, + user IDName, + opts *CreateContainerOptions, +) (EnvList, *oci.LinuxCapabilities, bool, error) { + return envList, opts.Capabilities, true, nil +} + func (OpenDoorSecurityPolicyEnforcer) EnforceExecInContainerPolicy(_ context.Context, _ string, _ []string, envList []string, _ string, _ bool, _ IDName, _ []IDName, _ string, caps *oci.LinuxCapabilities) (EnvList, *oci.LinuxCapabilities, bool, error) { return envList, caps, true, nil } +func (OpenDoorSecurityPolicyEnforcer) EnforceExecInContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + opts *ExecOptions, +) (EnvList, 
*oci.LinuxCapabilities, bool, error) { + return envList, opts.Capabilities, true, nil +} + func (OpenDoorSecurityPolicyEnforcer) EnforceExecExternalProcessPolicy(_ context.Context, _ []string, envList []string, _ string) (EnvList, bool, error) { return envList, true, nil } @@ -927,6 +1018,10 @@ func (*OpenDoorSecurityPolicyEnforcer) EnforceSignalContainerProcessPolicy(conte return nil } +func (*OpenDoorSecurityPolicyEnforcer) EnforceSignalContainerProcessPolicyV2(ctx context.Context, containerID string, opts *SignalContainerOptions) error { + return nil +} + func (*OpenDoorSecurityPolicyEnforcer) EnforcePlan9MountPolicy(context.Context, string) error { return nil } @@ -967,6 +1062,10 @@ func (OpenDoorSecurityPolicyEnforcer) EnforceScratchUnmountPolicy(context.Contex return nil } +func (OpenDoorSecurityPolicyEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) error { + return nil +} + func (OpenDoorSecurityPolicyEnforcer) GetUserInfo(containerID string, spec *oci.Process) (IDName, []IDName, string, error) { return IDName{}, nil, "", nil } @@ -997,10 +1096,34 @@ func (ClosedDoorSecurityPolicyEnforcer) EnforceCreateContainerPolicy(context.Con return nil, nil, false, errors.New("running commands is denied by policy") } +func (ClosedDoorSecurityPolicyEnforcer) EnforceCreateContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + mounts []oci.Mount, + user IDName, + opts *CreateContainerOptions, +) (EnvList, *oci.LinuxCapabilities, bool, error) { + return nil, nil, false, errors.New("running commands is denied by policy") +} + func (ClosedDoorSecurityPolicyEnforcer) EnforceExecInContainerPolicy(context.Context, string, []string, []string, string, bool, IDName, []IDName, string, *oci.LinuxCapabilities) (EnvList, *oci.LinuxCapabilities, bool, error) { return nil, nil, false, errors.New("starting additional processes in a container is denied by policy") } 
+func (ClosedDoorSecurityPolicyEnforcer) EnforceExecInContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + opts *ExecOptions, +) (EnvList, *oci.LinuxCapabilities, bool, error) { + return nil, nil, false, errors.New("starting additional processes in a container is denied by policy") +} + func (ClosedDoorSecurityPolicyEnforcer) EnforceExecExternalProcessPolicy(context.Context, []string, []string, string) (EnvList, bool, error) { return nil, false, errors.New("starting additional processes in uvm is denied by policy") } @@ -1013,6 +1136,10 @@ func (*ClosedDoorSecurityPolicyEnforcer) EnforceSignalContainerProcessPolicy(con return errors.New("signalling container processes is denied by policy") } +func (*ClosedDoorSecurityPolicyEnforcer) EnforceSignalContainerProcessPolicyV2(ctx context.Context, containerID string, opts *SignalContainerOptions) error { + return errors.New("signalling container processes is denied by policy") +} + func (*ClosedDoorSecurityPolicyEnforcer) EnforcePlan9MountPolicy(context.Context, string) error { return errors.New("mounting is denied by policy") } @@ -1053,6 +1180,10 @@ func (ClosedDoorSecurityPolicyEnforcer) EnforceScratchUnmountPolicy(context.Cont return errors.New("unmounting scratch is denied by the policy") } +func (ClosedDoorSecurityPolicyEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) error { + return nil +} + func (ClosedDoorSecurityPolicyEnforcer) GetUserInfo(containerID string, spec *oci.Process) (IDName, []IDName, string, error) { return IDName{}, nil, "", nil } diff --git a/pkg/securitypolicy/securitypolicyenforcer_rego.go b/pkg/securitypolicy/securitypolicyenforcer_rego.go index 52a0cbf571..94ddff7126 100644 --- a/pkg/securitypolicy/securitypolicyenforcer_rego.go +++ b/pkg/securitypolicy/securitypolicyenforcer_rego.go @@ -1,5 +1,5 @@ -//go:build linux && rego -// +build linux,rego +//go:build rego +// 
+build rego package securitypolicy @@ -9,20 +9,15 @@ import ( "encoding/base64" "encoding/json" "fmt" - "os" - "path/filepath" "strconv" "strings" "syscall" - "github.com/opencontainers/runc/libcontainer/user" - oci "github.com/opencontainers/runtime-spec/specs-go" - "github.com/pkg/errors" - - specGuest "github.com/Microsoft/hcsshim/internal/guest/spec" "github.com/Microsoft/hcsshim/internal/guestpath" "github.com/Microsoft/hcsshim/internal/log" rpi "github.com/Microsoft/hcsshim/internal/regopolicyinterpreter" + oci "github.com/opencontainers/runtime-spec/specs-go" + "github.com/pkg/errors" ) const regoEnforcerName = "rego" @@ -59,6 +54,8 @@ type regoEnforcer struct { stdio map[string]bool // Maximum error message length maxErrorMessageLength int + // OS type + osType string } var _ SecurityPolicyEnforcer = (*regoEnforcer)(nil) @@ -109,6 +106,7 @@ func createRegoEnforcer(base64EncodedPolicy string, defaultMounts []oci.Mount, privilegedMounts []oci.Mount, maxErrorMessageLength int, + osType string, ) (SecurityPolicyEnforcer, error) { // base64 decode the incoming policy string // It will either be (legacy) JSON or Rego. 
@@ -123,7 +121,7 @@ func createRegoEnforcer(base64EncodedPolicy string, err = json.Unmarshal(rawPolicy, securityPolicy) if err == nil { if securityPolicy.AllowAll { - return createOpenDoorEnforcer(base64EncodedPolicy, defaultMounts, privilegedMounts, maxErrorMessageLength) + return createOpenDoorEnforcer(base64EncodedPolicy, defaultMounts, privilegedMounts, maxErrorMessageLength, osType) } containers := make([]*Container, securityPolicy.Containers.Length) @@ -165,7 +163,7 @@ func createRegoEnforcer(base64EncodedPolicy string, code = string(rawPolicy) } - regoPolicy, err := newRegoPolicy(code, defaultMounts, privilegedMounts) + regoPolicy, err := newRegoPolicy(code, defaultMounts, privilegedMounts, osType) if err != nil { return nil, fmt.Errorf("error creating Rego policy: %w", err) } @@ -178,9 +176,10 @@ func (policy *regoEnforcer) enableLogging(path string, logLevel rpi.LogLevel) { policy.rego.EnableLogging(path, logLevel) } -func newRegoPolicy(code string, defaultMounts []oci.Mount, privilegedMounts []oci.Mount) (policy *regoEnforcer, err error) { +func newRegoPolicy(code string, defaultMounts []oci.Mount, privilegedMounts []oci.Mount, osType string) (policy *regoEnforcer, err error) { policy = new(regoEnforcer) + policy.osType = osType policy.defaultMounts = make([]oci.Mount, len(defaultMounts)) copy(policy.defaultMounts, defaultMounts) @@ -197,6 +196,7 @@ func newRegoPolicy(code string, defaultMounts []oci.Mount, privilegedMounts []oc } policy.rego, err = rpi.NewRegoPolicyInterpreter(code, data) + policy.rego.UpdateOSType(osType) if err != nil { return nil, err } @@ -711,25 +711,70 @@ func (policy *regoEnforcer) EnforceCreateContainerPolicy( capsToKeep *oci.LinuxCapabilities, stdioAccessAllowed bool, err error) { - if capabilities == nil { + opts := &CreateContainerOptions{ + SandboxID: sandboxID, + Privileged: &privileged, + NoNewPrivileges: &noNewPrivileges, + Groups: groups, + Umask: umask, + Capabilities: capabilities, + SeccompProfileSHA256: 
seccompProfileSHA256, + } + return policy.EnforceCreateContainerPolicyV2(ctx, containerID, argList, envList, workingDir, mounts, user, opts) +} + +func (policy *regoEnforcer) EnforceCreateContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + mounts []oci.Mount, + user IDName, + opts *CreateContainerOptions, +) (envToKeep EnvList, + capsToKeep *oci.LinuxCapabilities, + stdioAccessAllowed bool, + err error) { + + if policy.osType == "linux" && opts.Capabilities == nil { return nil, nil, false, errors.New(capabilitiesNilError) } - input := inputData{ - "containerID": containerID, - "argList": argList, - "envList": envList, - "workingDir": workingDir, - "sandboxDir": specGuest.SandboxMountsDir(sandboxID), - "hugePagesDir": specGuest.HugePagesMountsDir(sandboxID), - "mounts": appendMountData([]interface{}{}, mounts), - "privileged": privileged, - "noNewPrivileges": noNewPrivileges, - "user": user.toInput(), - "groups": groupsToInputs(groups), - "umask": umask, - "capabilities": mapifyCapabilities(capabilities), - "seccompProfileSHA256": seccompProfileSHA256, + var input inputData + + switch policy.osType { + case "linux": + input = inputData{ + "containerID": containerID, + "argList": argList, + "envList": envList, + "workingDir": workingDir, + "sandboxDir": SandboxMountsDir(opts.SandboxID), + "hugePagesDir": HugePagesMountsDir(opts.SandboxID), + "mounts": appendMountData([]interface{}{}, mounts), + "privileged": opts.Privileged, + "noNewPrivileges": opts.NoNewPrivileges, + "user": user.toInput(), + "groups": groupsToInputs(opts.Groups), + "umask": opts.Umask, + "capabilities": mapifyCapabilities(opts.Capabilities), + "seccompProfileSHA256": opts.SeccompProfileSHA256, + } + case "windows": + if envList == nil { + envList = []string{} + } + input = inputData{ + "containerID": containerID, + "argList": argList, + "envList": envList, + "workingDir": workingDir, + "privileged": true, + "user": 
user.Name, + } + default: + return nil, nil, false, errors.Errorf("unsupported OS value in options: %q", policy.osType) } results, err := policy.enforce(ctx, "create_container", input) @@ -742,9 +787,11 @@ func (policy *regoEnforcer) EnforceCreateContainerPolicy( return nil, nil, false, err } - capsToKeep, err = getCapsToKeep(capabilities, results) - if err != nil { - return nil, nil, false, err + if policy.osType == "linux" { + capsToKeep, err = getCapsToKeep(opts.Capabilities, results) + if err != nil { + return nil, nil, false, err + } } stdioAccessAllowed, err = results.Bool("allow_stdio_access") @@ -807,20 +854,57 @@ func (policy *regoEnforcer) EnforceExecInContainerPolicy( capsToKeep *oci.LinuxCapabilities, stdioAccessAllowed bool, err error) { - if capabilities == nil { + opts := &ExecOptions{ + User: &user, + Groups: groups, + Umask: umask, + Capabilities: capabilities, + NoNewPrivileges: &noNewPrivileges, + } + return policy.EnforceExecInContainerPolicyV2(ctx, containerID, argList, envList, workingDir, opts) +} + +func (policy *regoEnforcer) EnforceExecInContainerPolicyV2( + ctx context.Context, + containerID string, + argList []string, + envList []string, + workingDir string, + opts *ExecOptions, +) (envToKeep EnvList, + capsToKeep *oci.LinuxCapabilities, + stdioAccessAllowed bool, + err error) { + + if policy.osType == "linux" && opts.Capabilities == nil { return nil, nil, false, errors.New(capabilitiesNilError) } - input := inputData{ - "containerID": containerID, - "argList": argList, - "envList": envList, - "workingDir": workingDir, - "noNewPrivileges": noNewPrivileges, - "user": user.toInput(), - "groups": groupsToInputs(groups), - "umask": umask, - "capabilities": mapifyCapabilities(capabilities), + var input inputData + + switch policy.osType { + case "linux": + input = inputData{ + "containerID": containerID, + "argList": argList, + "envList": envList, + "workingDir": workingDir, + "noNewPrivileges": opts.NoNewPrivileges, + "user": 
opts.User.toInput(), + "groups": groupsToInputs(opts.Groups), + "umask": opts.Umask, + "capabilities": mapifyCapabilities(opts.Capabilities), + } + case "windows": + input = inputData{ + "containerID": containerID, + "argList": argList, + "envList": envList, + "workingDir": workingDir, + "user": opts.User.Name, + } + default: + return nil, nil, false, errors.Errorf("unsupported OS value in options: %q", policy.osType) } results, err := policy.enforce(ctx, "exec_in_container", input) @@ -833,11 +917,12 @@ func (policy *regoEnforcer) EnforceExecInContainerPolicy( return nil, nil, false, err } - capsToKeep, err = getCapsToKeep(capabilities, results) - if err != nil { - return nil, nil, false, err + if policy.osType == "linux" { + capsToKeep, err = getCapsToKeep(opts.Capabilities, results) + if err != nil { + return nil, nil, false, err + } } - return envToKeep, capsToKeep, policy.stdio[containerID], nil } @@ -887,6 +972,32 @@ func (policy *regoEnforcer) EnforceSignalContainerProcessPolicy(ctx context.Cont return err } +func (policy *regoEnforcer) EnforceSignalContainerProcessPolicyV2(ctx context.Context, containerID string, opts *SignalContainerOptions) error { + var input inputData + + switch policy.osType { + case "linux": + input = inputData{ + "containerID": containerID, + "signal": opts.LinuxSignal, + "isInitProcess": opts.IsInitProcess, + "argList": opts.LinuxStartupArgs, + } + case "windows": + input = inputData{ + "containerID": containerID, + "signal": opts.WindowsSignal, + "isInitProcess": opts.IsInitProcess, + "cmdLine": opts.WindowsCommand, + } + default: + return errors.Errorf("unsupported OS value in options: %q", policy.osType) + } + + _, err := policy.enforce(ctx, "signal_container_process", input) + return err +} + func (policy *regoEnforcer) EnforcePlan9MountPolicy(ctx context.Context, target string) error { mountPathPrefix := strings.Replace(guestpath.LCOWMountPathPrefixFmt, "%d", "[0-9]+", 1) input := inputData{ @@ -992,102 +1103,17 @@ func (policy 
*regoEnforcer) EnforceScratchUnmountPolicy(ctx context.Context, scr return nil } -func getUser(passwdPath string, filter func(user.User) bool) (user.User, error) { - users, err := user.ParsePasswdFileFilter(passwdPath, filter) - if err != nil { - return user.User{}, err - } - if len(users) != 1 { - return user.User{}, errors.Errorf("expected exactly 1 user matched '%d'", len(users)) +func (policy *regoEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) error { + log.G(ctx).Tracef("Enforcing verified cims in securitypolicy pkg %+v", layerHashes) + input := inputData{ + "containerID": containerID, + "layerHashes": layerHashes, } - return users[0], nil -} -func getGroup(groupPath string, filter func(user.Group) bool) (user.Group, error) { - groups, err := user.ParseGroupFileFilter(groupPath, filter) - if err != nil { - return user.Group{}, err - } - if len(groups) != 1 { - return user.Group{}, errors.Errorf("expected exactly 1 group matched '%d'", len(groups)) - } - return groups[0], nil + _, err := policy.enforce(ctx, "mount_cims", input) + return err } func (policy *regoEnforcer) GetUserInfo(containerID string, process *oci.Process) (IDName, []IDName, string, error) { - rootPath := filepath.Join(guestpath.LCOWRootPrefixInUVM, containerID, guestpath.RootfsPath) - passwdPath := filepath.Join(rootPath, "/etc/passwd") - groupPath := filepath.Join(rootPath, "/etc/group") - - if process == nil { - return IDName{}, nil, "", errors.New("spec.Process is nil") - } - - // this default value is used in the Linux kernel if no umask is specified - umask := "0022" - if process.User.Umask != nil { - umask = fmt.Sprintf("%04o", *process.User.Umask) - } - - if process.User.Username != "" { - uid, gid, err := specGuest.ParseUserStr(rootPath, process.User.Username) - if err == nil { - userIDName := IDName{ID: strconv.FormatUint(uint64(uid), 10)} - groupIDName := IDName{ID: strconv.FormatUint(uint64(gid), 10)} - return userIDName, 
[]IDName{groupIDName}, umask, nil - } - log.G(context.Background()).WithError(err).Warn("failed to parse user str, fallback to lookup") - } - - // fallback UID/GID lookup - uid := process.User.UID - userIDName := IDName{ID: strconv.FormatUint(uint64(uid), 10), Name: ""} - if _, err := os.Stat(passwdPath); err == nil { - userInfo, err := getUser(passwdPath, func(user user.User) bool { - return uint32(user.Uid) == uid - }) - - if err != nil { - return userIDName, nil, "", err - } - - userIDName.Name = userInfo.Name - } - - gid := process.User.GID - groupIDName := IDName{ID: strconv.FormatUint(uint64(gid), 10), Name: ""} - - checkGroup := true - if _, err := os.Stat(groupPath); err == nil { - groupInfo, err := getGroup(groupPath, func(group user.Group) bool { - return uint32(group.Gid) == gid - }) - - if err != nil { - return userIDName, nil, "", err - } - groupIDName.Name = groupInfo.Name - } else { - checkGroup = false - } - - groupIDNames := []IDName{groupIDName} - additionalGIDs := process.User.AdditionalGids - if len(additionalGIDs) > 0 { - for _, gid := range additionalGIDs { - groupIDName = IDName{ID: strconv.FormatUint(uint64(gid), 10), Name: ""} - if checkGroup { - groupInfo, err := getGroup(groupPath, func(group user.Group) bool { - return uint32(group.Gid) == gid - }) - if err != nil { - return userIDName, nil, "", err - } - groupIDName.Name = groupInfo.Name - } - groupIDNames = append(groupIDNames, groupIDName) - } - } - - return userIDName, groupIDNames, umask, nil + return GetAllUserInfo(containerID, process) } diff --git a/pkg/securitypolicy/version_api b/pkg/securitypolicy/version_api index 2774f8587f..142464bf22 100644 --- a/pkg/securitypolicy/version_api +++ b/pkg/securitypolicy/version_api @@ -1 +1 @@ -0.10.0 \ No newline at end of file +0.11.0 \ No newline at end of file diff --git a/pkg/securitypolicy/version_framework b/pkg/securitypolicy/version_framework index 9325c3ccda..60a2d3e96c 100644 --- a/pkg/securitypolicy/version_framework +++ 
b/pkg/securitypolicy/version_framework @@ -1 +1 @@ -0.3.0 \ No newline at end of file +0.4.0 \ No newline at end of file diff --git a/test/gcs/container_test.go b/test/gcs/container_test.go index 6a68e55c33..2f29f3979e 100644 --- a/test/gcs/container_test.go +++ b/test/gcs/container_test.go @@ -12,7 +12,7 @@ import ( "github.com/containerd/containerd/oci" "golang.org/x/sync/errgroup" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" + "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/guest/stdio" testoci "github.com/Microsoft/hcsshim/test/internal/oci" diff --git a/vendor/github.com/opencontainers/runc/libcontainer/user/lookup_deprecated.go b/vendor/github.com/opencontainers/runc/libcontainer/user/lookup_deprecated.go deleted file mode 100644 index c6cd443455..0000000000 --- a/vendor/github.com/opencontainers/runc/libcontainer/user/lookup_deprecated.go +++ /dev/null @@ -1,81 +0,0 @@ -package user - -import ( - "io" - - "github.com/moby/sys/user" -) - -// LookupUser looks up a user by their username in /etc/passwd. If the user -// cannot be found (or there is no /etc/passwd file on the filesystem), then -// LookupUser returns an error. -func LookupUser(username string) (user.User, error) { - return user.LookupUser(username) -} - -// LookupUid looks up a user by their user id in /etc/passwd. If the user cannot -// be found (or there is no /etc/passwd file on the filesystem), then LookupId -// returns an error. -func LookupUid(uid int) (user.User, error) { //nolint:revive // ignore var-naming: func LookupUid should be LookupUID - return user.LookupUid(uid) -} - -// LookupGroup looks up a group by its name in /etc/group. If the group cannot -// be found (or there is no /etc/group file on the filesystem), then LookupGroup -// returns an error. -func LookupGroup(groupname string) (user.Group, error) { - return user.LookupGroup(groupname) -} - -// LookupGid looks up a group by its group id in /etc/group. 
If the group cannot -// be found (or there is no /etc/group file on the filesystem), then LookupGid -// returns an error. -func LookupGid(gid int) (user.Group, error) { - return user.LookupGid(gid) -} - -func GetPasswdPath() (string, error) { - return user.GetPasswdPath() -} - -func GetPasswd() (io.ReadCloser, error) { - return user.GetPasswd() -} - -func GetGroupPath() (string, error) { - return user.GetGroupPath() -} - -func GetGroup() (io.ReadCloser, error) { - return user.GetGroup() -} - -// CurrentUser looks up the current user by their user id in /etc/passwd. If the -// user cannot be found (or there is no /etc/passwd file on the filesystem), -// then CurrentUser returns an error. -func CurrentUser() (user.User, error) { - return user.CurrentUser() -} - -// CurrentGroup looks up the current user's group by their primary group id's -// entry in /etc/passwd. If the group cannot be found (or there is no -// /etc/group file on the filesystem), then CurrentGroup returns an error. -func CurrentGroup() (user.Group, error) { - return user.CurrentGroup() -} - -func CurrentUserSubUIDs() ([]user.SubID, error) { - return user.CurrentUserSubUIDs() -} - -func CurrentUserSubGIDs() ([]user.SubID, error) { - return user.CurrentUserSubGIDs() -} - -func CurrentProcessUIDMap() ([]user.IDMap, error) { - return user.CurrentProcessUIDMap() -} - -func CurrentProcessGIDMap() ([]user.IDMap, error) { - return user.CurrentProcessGIDMap() -} diff --git a/vendor/github.com/opencontainers/runc/libcontainer/user/user_deprecated.go b/vendor/github.com/opencontainers/runc/libcontainer/user/user_deprecated.go deleted file mode 100644 index 3c29f3d1d8..0000000000 --- a/vendor/github.com/opencontainers/runc/libcontainer/user/user_deprecated.go +++ /dev/null @@ -1,146 +0,0 @@ -// Package user is an alias for [github.com/moby/sys/user]. -// -// Deprecated: use [github.com/moby/sys/user]. 
-package user - -import ( - "io" - - "github.com/moby/sys/user" -) - -var ( - // ErrNoPasswdEntries is returned if no matching entries were found in /etc/group. - ErrNoPasswdEntries = user.ErrNoPasswdEntries - // ErrNoGroupEntries is returned if no matching entries were found in /etc/passwd. - ErrNoGroupEntries = user.ErrNoGroupEntries - // ErrRange is returned if a UID or GID is outside of the valid range. - ErrRange = user.ErrRange -) - -type ( - User = user.User - - Group = user.Group - - // SubID represents an entry in /etc/sub{u,g}id. - SubID = user.SubID - - // IDMap represents an entry in /proc/PID/{u,g}id_map. - IDMap = user.IDMap - - ExecUser = user.ExecUser -) - -func ParsePasswdFile(path string) ([]user.User, error) { - return user.ParsePasswdFile(path) -} - -func ParsePasswd(passwd io.Reader) ([]user.User, error) { - return user.ParsePasswd(passwd) -} - -func ParsePasswdFileFilter(path string, filter func(user.User) bool) ([]user.User, error) { - return user.ParsePasswdFileFilter(path, filter) -} - -func ParsePasswdFilter(r io.Reader, filter func(user.User) bool) ([]user.User, error) { - return user.ParsePasswdFilter(r, filter) -} - -func ParseGroupFile(path string) ([]user.Group, error) { - return user.ParseGroupFile(path) -} - -func ParseGroup(group io.Reader) ([]user.Group, error) { - return user.ParseGroup(group) -} - -func ParseGroupFileFilter(path string, filter func(user.Group) bool) ([]user.Group, error) { - return user.ParseGroupFileFilter(path, filter) -} - -func ParseGroupFilter(r io.Reader, filter func(user.Group) bool) ([]user.Group, error) { - return user.ParseGroupFilter(r, filter) -} - -// GetExecUserPath is a wrapper for GetExecUser. It reads data from each of the -// given file paths and uses that data as the arguments to GetExecUser. If the -// files cannot be opened for any reason, the error is ignored and a nil -// io.Reader is passed instead. 
-func GetExecUserPath(userSpec string, defaults *user.ExecUser, passwdPath, groupPath string) (*user.ExecUser, error) { - return user.GetExecUserPath(userSpec, defaults, passwdPath, groupPath) -} - -// GetExecUser parses a user specification string (using the passwd and group -// readers as sources for /etc/passwd and /etc/group data, respectively). In -// the case of blank fields or missing data from the sources, the values in -// defaults is used. -// -// GetExecUser will return an error if a user or group literal could not be -// found in any entry in passwd and group respectively. -// -// Examples of valid user specifications are: -// - "" -// - "user" -// - "uid" -// - "user:group" -// - "uid:gid -// - "user:gid" -// - "uid:group" -// -// It should be noted that if you specify a numeric user or group id, they will -// not be evaluated as usernames (only the metadata will be filled). So attempting -// to parse a user with user.Name = "1337" will produce the user with a UID of -// 1337. -func GetExecUser(userSpec string, defaults *user.ExecUser, passwd, group io.Reader) (*user.ExecUser, error) { - return user.GetExecUser(userSpec, defaults, passwd, group) -} - -// GetAdditionalGroups looks up a list of groups by name or group id -// against the given /etc/group formatted data. If a group name cannot -// be found, an error will be returned. If a group id cannot be found, -// or the given group data is nil, the id will be returned as-is -// provided it is in the legal range. -func GetAdditionalGroups(additionalGroups []string, group io.Reader) ([]int, error) { - return user.GetAdditionalGroups(additionalGroups, group) -} - -// GetAdditionalGroupsPath is a wrapper around GetAdditionalGroups -// that opens the groupPath given and gives it as an argument to -// GetAdditionalGroups. 
-func GetAdditionalGroupsPath(additionalGroups []string, groupPath string) ([]int, error) { - return user.GetAdditionalGroupsPath(additionalGroups, groupPath) -} - -func ParseSubIDFile(path string) ([]user.SubID, error) { - return user.ParseSubIDFile(path) -} - -func ParseSubID(subid io.Reader) ([]user.SubID, error) { - return user.ParseSubID(subid) -} - -func ParseSubIDFileFilter(path string, filter func(user.SubID) bool) ([]user.SubID, error) { - return user.ParseSubIDFileFilter(path, filter) -} - -func ParseSubIDFilter(r io.Reader, filter func(user.SubID) bool) ([]user.SubID, error) { - return user.ParseSubIDFilter(r, filter) -} - -func ParseIDMapFile(path string) ([]user.IDMap, error) { - return user.ParseIDMapFile(path) -} - -func ParseIDMap(r io.Reader) ([]user.IDMap, error) { - return user.ParseIDMap(r) -} - -func ParseIDMapFileFilter(path string, filter func(user.IDMap) bool) ([]user.IDMap, error) { - return user.ParseIDMapFileFilter(path, filter) -} - -func ParseIDMapFilter(r io.Reader, filter func(user.IDMap) bool) ([]user.IDMap, error) { - return user.ParseIDMapFilter(r, filter) -} diff --git a/vendor/modules.txt b/vendor/modules.txt index 0deb3148a3..9c4e6419a4 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -397,7 +397,6 @@ github.com/opencontainers/image-spec/specs-go/v1 # github.com/opencontainers/runc v1.2.3 ## explicit; go 1.22 github.com/opencontainers/runc/libcontainer/devices -github.com/opencontainers/runc/libcontainer/user # github.com/opencontainers/runtime-spec v1.2.0 ## explicit github.com/opencontainers/runtime-spec/specs-go