Merge "Extract out common config parsing for ConfigPool"
This commit is contained in:
commit
1fe5fb60c5
|
@ -824,7 +824,11 @@ class ConfigValue(object, metaclass=abc.ABCMeta):
|
|||
return not self.__eq__(other)
|
||||
|
||||
|
||||
class ConfigPool(ConfigValue):
|
||||
class ConfigPool(ConfigValue, metaclass=abc.ABCMeta):
|
||||
'''
|
||||
Base class for a single pool as defined in the configuration file.
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
self.labels = {}
|
||||
self.max_servers = math.inf
|
||||
|
@ -837,6 +841,40 @@ class ConfigPool(ConfigValue):
|
|||
self.node_attributes == other.node_attributes)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def getCommonSchemaDict(self):
|
||||
'''
|
||||
Return the schema dict for common pool attributes.
|
||||
|
||||
When a driver validates its own configuration schema, it should call
|
||||
this class method to get and include the common pool attributes in
|
||||
the schema.
|
||||
|
||||
The `labels` attribute, though common, can vary its type across
|
||||
drivers so it is not returned in the schema.
|
||||
'''
|
||||
return {
|
||||
'max-servers': int,
|
||||
'node-attributes': dict,
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, pool_config):
|
||||
'''
|
||||
Load pool config options from the parsed configuration file.
|
||||
|
||||
Subclasses are expected to call the parent method so that common
|
||||
configuration values are loaded properly.
|
||||
|
||||
Although `labels` is a common attribute, each driver may
|
||||
define it differently, so we cannot parse that attribute here.
|
||||
|
||||
:param dict pool_config: A single pool config section from which we
|
||||
will load the values.
|
||||
'''
|
||||
self.max_servers = pool_config.get('max-servers', math.inf)
|
||||
self.node_attributes = pool_config.get('node-attributes')
|
||||
|
||||
|
||||
class DriverConfig(ConfigValue):
|
||||
def __init__(self):
|
||||
|
|
|
@ -45,6 +45,20 @@ class KubernetesPool(ConfigPool):
|
|||
def __repr__(self):
|
||||
return "<KubernetesPool %s>" % self.name
|
||||
|
||||
def load(self, pool_config, full_config):
|
||||
super().load(pool_config)
|
||||
self.name = pool_config['name']
|
||||
self.labels = {}
|
||||
for label in pool_config.get('labels', []):
|
||||
pl = KubernetesLabel()
|
||||
pl.name = label['name']
|
||||
pl.type = label['type']
|
||||
pl.image = label.get('image')
|
||||
pl.image_pull = label.get('image-pull', 'IfNotPresent')
|
||||
pl.pool = self
|
||||
self.labels[pl.name] = pl
|
||||
full_config.labels[label['name']].pools.append(self)
|
||||
|
||||
|
||||
class KubernetesProviderConfig(ProviderConfig):
|
||||
def __init__(self, driver, provider):
|
||||
|
@ -72,19 +86,9 @@ class KubernetesProviderConfig(ProviderConfig):
|
|||
self.context = self.provider['context']
|
||||
for pool in self.provider.get('pools', []):
|
||||
pp = KubernetesPool()
|
||||
pp.name = pool['name']
|
||||
pp.load(pool, config)
|
||||
pp.provider = self
|
||||
self.pools[pp.name] = pp
|
||||
pp.labels = {}
|
||||
for label in pool.get('labels', []):
|
||||
pl = KubernetesLabel()
|
||||
pl.name = label['name']
|
||||
pl.type = label['type']
|
||||
pl.image = label.get('image')
|
||||
pl.image_pull = label.get('image-pull', 'IfNotPresent')
|
||||
pl.pool = pp
|
||||
pp.labels[pl.name] = pl
|
||||
config.labels[label['name']].pools.append(pp)
|
||||
|
||||
def getSchema(self):
|
||||
k8s_label = {
|
||||
|
@ -94,10 +98,11 @@ class KubernetesProviderConfig(ProviderConfig):
|
|||
'image-pull': str,
|
||||
}
|
||||
|
||||
pool = {
|
||||
pool = ConfigPool.getCommonSchemaDict()
|
||||
pool.update({
|
||||
v.Required('name'): str,
|
||||
v.Required('labels'): [k8s_label],
|
||||
}
|
||||
})
|
||||
|
||||
provider = {
|
||||
v.Required('pools'): [pool],
|
||||
|
|
|
@ -149,6 +149,64 @@ class ProviderPool(ConfigPool):
|
|||
def __repr__(self):
|
||||
return "<ProviderPool %s>" % self.name
|
||||
|
||||
def load(self, pool_config, full_config, provider):
|
||||
'''
|
||||
Load pool configuration options.
|
||||
|
||||
:param dict pool_config: A single pool config section from which we
|
||||
will load the values.
|
||||
:param dict full_config: The full nodepool config.
|
||||
:param OpenStackProviderConfig: The calling provider object.
|
||||
'''
|
||||
super().load(pool_config)
|
||||
|
||||
self.provider = provider
|
||||
self.name = pool_config['name']
|
||||
self.max_cores = pool_config.get('max-cores', math.inf)
|
||||
self.max_ram = pool_config.get('max-ram', math.inf)
|
||||
self.ignore_provider_quota = pool_config.get('ignore-provider-quota',
|
||||
False)
|
||||
self.azs = pool_config.get('availability-zones')
|
||||
self.networks = pool_config.get('networks', [])
|
||||
self.security_groups = pool_config.get('security-groups', [])
|
||||
self.auto_floating_ip = bool(pool_config.get('auto-floating-ip', True))
|
||||
self.host_key_checking = bool(pool_config.get('host-key-checking',
|
||||
True))
|
||||
|
||||
for label in pool_config.get('labels', []):
|
||||
pl = ProviderLabel()
|
||||
pl.name = label['name']
|
||||
pl.pool = self
|
||||
self.labels[pl.name] = pl
|
||||
diskimage = label.get('diskimage', None)
|
||||
if diskimage:
|
||||
pl.diskimage = full_config.diskimages[diskimage]
|
||||
else:
|
||||
pl.diskimage = None
|
||||
cloud_image_name = label.get('cloud-image', None)
|
||||
if cloud_image_name:
|
||||
cloud_image = provider.cloud_images.get(cloud_image_name, None)
|
||||
if not cloud_image:
|
||||
raise ValueError(
|
||||
"cloud-image %s does not exist in provider %s"
|
||||
" but is referenced in label %s" %
|
||||
(cloud_image_name, self.name, pl.name))
|
||||
else:
|
||||
cloud_image = None
|
||||
pl.cloud_image = cloud_image
|
||||
pl.min_ram = label.get('min-ram', 0)
|
||||
pl.flavor_name = label.get('flavor-name', None)
|
||||
pl.key_name = label.get('key-name')
|
||||
pl.console_log = label.get('console-log', False)
|
||||
pl.boot_from_volume = bool(label.get('boot-from-volume',
|
||||
False))
|
||||
pl.volume_size = label.get('volume-size', 50)
|
||||
pl.instance_properties = label.get('instance-properties',
|
||||
None)
|
||||
|
||||
top_label = full_config.labels[pl.name]
|
||||
top_label.pools.append(self)
|
||||
|
||||
|
||||
class OpenStackProviderConfig(ProviderConfig):
|
||||
def __init__(self, driver, provider):
|
||||
|
@ -263,53 +321,8 @@ class OpenStackProviderConfig(ProviderConfig):
|
|||
|
||||
for pool in self.provider.get('pools', []):
|
||||
pp = ProviderPool()
|
||||
pp.name = pool['name']
|
||||
pp.provider = self
|
||||
pp.load(pool, config, self)
|
||||
self.pools[pp.name] = pp
|
||||
pp.max_cores = pool.get('max-cores', math.inf)
|
||||
pp.max_servers = pool.get('max-servers', math.inf)
|
||||
pp.max_ram = pool.get('max-ram', math.inf)
|
||||
pp.ignore_provider_quota = pool.get('ignore-provider-quota', False)
|
||||
pp.azs = pool.get('availability-zones')
|
||||
pp.networks = pool.get('networks', [])
|
||||
pp.security_groups = pool.get('security-groups', [])
|
||||
pp.auto_floating_ip = bool(pool.get('auto-floating-ip', True))
|
||||
pp.host_key_checking = bool(pool.get('host-key-checking', True))
|
||||
pp.node_attributes = pool.get('node-attributes')
|
||||
|
||||
for label in pool.get('labels', []):
|
||||
pl = ProviderLabel()
|
||||
pl.name = label['name']
|
||||
pl.pool = pp
|
||||
pp.labels[pl.name] = pl
|
||||
diskimage = label.get('diskimage', None)
|
||||
if diskimage:
|
||||
pl.diskimage = config.diskimages[diskimage]
|
||||
else:
|
||||
pl.diskimage = None
|
||||
cloud_image_name = label.get('cloud-image', None)
|
||||
if cloud_image_name:
|
||||
cloud_image = self.cloud_images.get(cloud_image_name, None)
|
||||
if not cloud_image:
|
||||
raise ValueError(
|
||||
"cloud-image %s does not exist in provider %s"
|
||||
" but is referenced in label %s" %
|
||||
(cloud_image_name, self.name, pl.name))
|
||||
else:
|
||||
cloud_image = None
|
||||
pl.cloud_image = cloud_image
|
||||
pl.min_ram = label.get('min-ram', 0)
|
||||
pl.flavor_name = label.get('flavor-name', None)
|
||||
pl.key_name = label.get('key-name')
|
||||
pl.console_log = label.get('console-log', False)
|
||||
pl.boot_from_volume = bool(label.get('boot-from-volume',
|
||||
False))
|
||||
pl.volume_size = label.get('volume-size', 50)
|
||||
pl.instance_properties = label.get('instance-properties',
|
||||
None)
|
||||
|
||||
top_label = config.labels[pl.name]
|
||||
top_label.pools.append(pp)
|
||||
|
||||
def getSchema(self):
|
||||
provider_diskimage = {
|
||||
|
@ -358,20 +371,19 @@ class OpenStackProviderConfig(ProviderConfig):
|
|||
v.Any(label_min_ram, label_flavor_name),
|
||||
v.Any(label_diskimage, label_cloud_image))
|
||||
|
||||
pool = {
|
||||
pool = ConfigPool.getCommonSchemaDict()
|
||||
pool.update({
|
||||
'name': str,
|
||||
'networks': [str],
|
||||
'auto-floating-ip': bool,
|
||||
'host-key-checking': bool,
|
||||
'ignore-provider-quota': bool,
|
||||
'max-cores': int,
|
||||
'max-servers': int,
|
||||
'max-ram': int,
|
||||
'labels': [pool_label],
|
||||
'node-attributes': dict,
|
||||
'availability-zones': [str],
|
||||
'security-groups': [str]
|
||||
}
|
||||
})
|
||||
|
||||
return v.Schema({
|
||||
'region-name': str,
|
||||
|
|
|
@ -41,6 +41,33 @@ class StaticPool(ConfigPool):
|
|||
def __repr__(self):
|
||||
return "<StaticPool %s>" % self.name
|
||||
|
||||
def load(self, pool_config, full_config):
|
||||
super().load(pool_config)
|
||||
self.name = pool_config['name']
|
||||
# WARNING: This intentionally changes the type!
|
||||
self.labels = set()
|
||||
for node in pool_config.get('nodes', []):
|
||||
self.nodes.append({
|
||||
'name': node['name'],
|
||||
'labels': as_list(node['labels']),
|
||||
'host-key': as_list(node.get('host-key', [])),
|
||||
'timeout': int(node.get('timeout', 5)),
|
||||
# Read ssh-port values for backward compat, but prefer port
|
||||
'connection-port': int(
|
||||
node.get('connection-port', node.get('ssh-port', 22))),
|
||||
'connection-type': node.get('connection-type', 'ssh'),
|
||||
'username': node.get('username', 'zuul'),
|
||||
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
|
||||
})
|
||||
if isinstance(node['labels'], str):
|
||||
for label in node['labels'].split():
|
||||
self.labels.add(label)
|
||||
full_config.labels[label].pools.append(self)
|
||||
elif isinstance(node['labels'], list):
|
||||
for label in node['labels']:
|
||||
self.labels.add(label)
|
||||
full_config.labels[label].pools.append(self)
|
||||
|
||||
|
||||
class StaticProviderConfig(ProviderConfig):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -65,32 +92,9 @@ class StaticProviderConfig(ProviderConfig):
|
|||
def load(self, config):
|
||||
for pool in self.provider.get('pools', []):
|
||||
pp = StaticPool()
|
||||
pp.name = pool['name']
|
||||
pp.load(pool, config)
|
||||
pp.provider = self
|
||||
self.pools[pp.name] = pp
|
||||
# WARNING: This intentionally changes the type!
|
||||
pp.labels = set()
|
||||
for node in pool.get('nodes', []):
|
||||
pp.nodes.append({
|
||||
'name': node['name'],
|
||||
'labels': as_list(node['labels']),
|
||||
'host-key': as_list(node.get('host-key', [])),
|
||||
'timeout': int(node.get('timeout', 5)),
|
||||
# Read ssh-port values for backward compat, but prefer port
|
||||
'connection-port': int(
|
||||
node.get('connection-port', node.get('ssh-port', 22))),
|
||||
'connection-type': node.get('connection-type', 'ssh'),
|
||||
'username': node.get('username', 'zuul'),
|
||||
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
|
||||
})
|
||||
if isinstance(node['labels'], str):
|
||||
for label in node['labels'].split():
|
||||
pp.labels.add(label)
|
||||
config.labels[label].pools.append(pp)
|
||||
elif isinstance(node['labels'], list):
|
||||
for label in node['labels']:
|
||||
pp.labels.add(label)
|
||||
config.labels[label].pools.append(pp)
|
||||
|
||||
def getSchema(self):
|
||||
pool_node = {
|
||||
|
@ -103,10 +107,11 @@ class StaticProviderConfig(ProviderConfig):
|
|||
'connection-type': str,
|
||||
'max-parallel-jobs': int,
|
||||
}
|
||||
pool = {
|
||||
pool = ConfigPool.getCommonSchemaDict()
|
||||
pool.update({
|
||||
'name': str,
|
||||
'nodes': [pool_node],
|
||||
}
|
||||
})
|
||||
return v.Schema({'pools': [pool]})
|
||||
|
||||
def getSupportedLabels(self, pool_name=None):
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import math
|
||||
import voluptuous as v
|
||||
|
||||
from nodepool.driver import ConfigPool
|
||||
|
@ -20,7 +19,10 @@ from nodepool.driver import ProviderConfig
|
|||
|
||||
|
||||
class TestPool(ConfigPool):
|
||||
pass
|
||||
def load(self, pool_config):
|
||||
super().load(pool_config)
|
||||
self.name = pool_config['name']
|
||||
self.labels = pool_config['labels']
|
||||
|
||||
|
||||
class TestConfig(ProviderConfig):
|
||||
|
@ -43,18 +45,19 @@ class TestConfig(ProviderConfig):
|
|||
self.labels = set()
|
||||
for pool in self.provider.get('pools', []):
|
||||
testpool = TestPool()
|
||||
testpool.name = pool['name']
|
||||
testpool.load(pool)
|
||||
testpool.provider = self
|
||||
testpool.max_servers = pool.get('max-servers', math.inf)
|
||||
testpool.labels = pool['labels']
|
||||
for label in pool['labels']:
|
||||
self.labels.add(label)
|
||||
newconfig.labels[label].pools.append(testpool)
|
||||
self.pools[pool['name']] = testpool
|
||||
|
||||
def getSchema(self):
|
||||
pool = {'name': str,
|
||||
'labels': [str]}
|
||||
pool = ConfigPool.getCommonSchemaDict()
|
||||
pool.update({
|
||||
'name': str,
|
||||
'labels': [str]
|
||||
})
|
||||
return v.Schema({'pools': [pool]})
|
||||
|
||||
def getSupportedLabels(self, pool_name=None):
|
||||
|
|
|
@ -28,11 +28,17 @@ from nodepool.driver.static.config import StaticPool
|
|||
from nodepool.driver.static.config import StaticProviderConfig
|
||||
|
||||
|
||||
class TempConfigPool(ConfigPool):
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestConfigComparisons(tests.BaseTestCase):
|
||||
|
||||
def test_ConfigPool(self):
|
||||
a = ConfigPool()
|
||||
b = ConfigPool()
|
||||
|
||||
a = TempConfigPool()
|
||||
b = TempConfigPool()
|
||||
self.assertEqual(a, b)
|
||||
a.max_servers = 5
|
||||
self.assertNotEqual(a, b)
|
||||
|
@ -94,9 +100,9 @@ class TestConfigComparisons(tests.BaseTestCase):
|
|||
a.max_servers = 5
|
||||
self.assertNotEqual(a, b)
|
||||
|
||||
c = ConfigPool()
|
||||
c = TempConfigPool()
|
||||
d = ProviderPool()
|
||||
self.assertNotEqual(c, d)
|
||||
self.assertNotEqual(d, c)
|
||||
|
||||
def test_OpenStackProviderConfig(self):
|
||||
provider = {'name': 'foo'}
|
||||
|
@ -114,7 +120,7 @@ class TestConfigComparisons(tests.BaseTestCase):
|
|||
# intentionally change an attribute of the base class
|
||||
a.max_servers = 5
|
||||
self.assertNotEqual(a, b)
|
||||
c = ConfigPool()
|
||||
c = TempConfigPool()
|
||||
self.assertNotEqual(b, c)
|
||||
|
||||
def test_StaticProviderConfig(self):
|
||||
|
|
Loading…
Reference in New Issue