From 31e40176da32cf8aa1408ce61de8e4f197a223de Mon Sep 17 00:00:00 2001 From: aiordache Date: Wed, 6 Jan 2021 17:11:03 +0100 Subject: [PATCH] Add GPU support via DeviceRequests Signed-off-by: aiordache --- ecs/convert.go | 2 +- go.mod | 2 +- go.sum | 4 ++-- local/compose/create.go | 25 +++++++++++++++++++++++++ 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/ecs/convert.go b/ecs/convert.go index 3ba7e8bf6..220f9670f 100644 --- a/ecs/convert.go +++ b/ecs/convert.go @@ -564,7 +564,7 @@ func gpuRequirements(s types.ServiceConfig) int64 { } for _, device := range reservations.Devices { if len(device.Capabilities) == 1 && device.Capabilities[0] == "gpu" { - return int64(device.Count) + return device.Count } } } diff --git a/go.mod b/go.mod index dab8c3d3e..f1473b763 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/aws/aws-sdk-go v1.35.33 github.com/awslabs/goformation/v4 v4.15.6 github.com/buger/goterm v0.0.0-20200322175922-2f3e71b85129 - github.com/compose-spec/compose-go v0.0.0-20201210155915-b5ef325e9175 + github.com/compose-spec/compose-go v0.0.0-20210106202047-687be5e0e320 github.com/containerd/console v1.0.1 github.com/containerd/containerd v1.4.3 github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a // indirect diff --git a/go.sum b/go.sum index 2f3ab0ff1..95aa4cab4 100644 --- a/go.sum +++ b/go.sum @@ -236,8 +236,8 @@ github.com/cloudflare/cfssl v0.0.0-20181213083726-b94e044bb51e/go.mod h1:yMWuSON github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= github.com/codahale/hdrhistogram v0.0.0-20160425231609-f8ad88b59a58/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= -github.com/compose-spec/compose-go v0.0.0-20201210155915-b5ef325e9175 h1:6ZE967wCKnx4h+OIUsjnS113itBlncF3ls/Ia7rKcbc= -github.com/compose-spec/compose-go v0.0.0-20201210155915-b5ef325e9175/go.mod h1:rz7rjxJGA/pWpLdBmDdqymGm2okEDYgBE7yx569xW+I= +github.com/compose-spec/compose-go v0.0.0-20210106202047-687be5e0e320 h1:PjwzjUYqjto8PLdHLtPX2/JtCbYYsKMs1Zof7/h29YA= +github.com/compose-spec/compose-go v0.0.0-20210106202047-687be5e0e320/go.mod h1:rz7rjxJGA/pWpLdBmDdqymGm2okEDYgBE7yx569xW+I= github.com/containerd/cgroups v0.0.0-20190919134610-bf292b21730f/go.mod h1:OApqhQ4XNSNC13gXIwDjhOQxjWa/NxkwZXJ1EvqT0ko= github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59/go.mod h1:pA0z1pT8KYB3TCXK/ocprsh7MAkoW8bZVzPdih9snmM= github.com/containerd/cgroups v0.0.0-20200710171044-318312a37340 h1:9atoWyI9RtXFwf7UDbme/6M8Ud0rFrx+Q3ZWgSnsxtw= diff --git a/local/compose/create.go b/local/compose/create.go index b97a8173f..992e0e747 100644 --- a/local/compose/create.go +++ b/local/compose/create.go @@ -165,6 +165,7 @@ func getCreateOptions(p *types.Project, s types.ServiceConfig, number int, inher } bindings := buildContainerBindingOptions(s) + resources := getDeployResources(s) networkMode := getNetworkMode(p, s) hostConfig := container.HostConfig{ AutoRemove: autoRemove, @@ -177,12 +178,36 @@ func getCreateOptions(p *types.Project, s types.ServiceConfig, number int, inher // ShmSize: , TODO Sysctls: s.Sysctls, PortBindings: bindings, + Resources: resources, } networkConfig := buildDefaultNetworkConfig(s, networkMode) return &containerConfig, &hostConfig, networkConfig, nil } +func getDeployResources(s types.ServiceConfig) container.Resources { + resources := container.Resources{} + if s.Deploy == nil { + return resources + } + + reservations := s.Deploy.Resources.Reservations + + if reservations == nil || len(reservations.Devices) == 0 { + return resources + } + + for _, device := range reservations.Devices { + resources.DeviceRequests = append(resources.DeviceRequests, container.DeviceRequest{ + Capabilities: [][]string{device.Capabilities}, + Count: int(device.Count), + DeviceIDs: device.IDs, + Driver: device.Driver, + }) + } + return resources +} + func buildContainerPorts(s types.ServiceConfig) nat.PortSet { ports := nat.PortSet{} for _, p := range s.Ports {