Extract out common config parsing for ConfigPool
Our driver code is in a less-than-ideal situation where each driver is responsible for parsing config options that are common to all drivers. This change begins to correct that, starting with ConfigPool. It changes the driver API in the following ways: 1) Forces objects derived from ConfigPool to implement a load() method that should call the superclass method, then handle loading driver-specific options from the config. 2) Adds a ConfigPool class method that can be called to get the config schema for the common config options, leaving drivers to define only the schema for their own config options. Other base config objects will be modeled after this pattern in later changes. Change-Id: I41620590c355cacd2c4fbe6916acfe80f20e3216
This commit is contained in:
parent
f116826d2b
commit
a19dffd916
|
@ -824,7 +824,11 @@ class ConfigValue(object, metaclass=abc.ABCMeta):
|
|||
return not self.__eq__(other)
|
||||
|
||||
|
||||
class ConfigPool(ConfigValue):
|
||||
class ConfigPool(ConfigValue, metaclass=abc.ABCMeta):
|
||||
'''
|
||||
Base class for a single pool as defined in the configuration file.
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
self.labels = {}
|
||||
self.max_servers = math.inf
|
||||
|
@ -837,6 +841,40 @@ class ConfigPool(ConfigValue):
|
|||
self.node_attributes == other.node_attributes)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def getCommonSchemaDict(self):
|
||||
'''
|
||||
Return the schema dict for common pool attributes.
|
||||
|
||||
When a driver validates its own configuration schema, it should call
|
||||
this class method to get and include the common pool attributes in
|
||||
the schema.
|
||||
|
||||
The `labels` attribute, though common, can vary its type across
|
||||
drivers so it is not returned in the schema.
|
||||
'''
|
||||
return {
|
||||
'max-servers': int,
|
||||
'node-attributes': dict,
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, pool_config):
|
||||
'''
|
||||
Load pool config options from the parsed configuration file.
|
||||
|
||||
Subclasses are expected to call the parent method so that common
|
||||
configuration values are loaded properly.
|
||||
|
||||
Although `labels` is a common attribute, each driver may
|
||||
define it differently, so we cannot parse that attribute here.
|
||||
|
||||
:param dict pool_config: A single pool config section from which we
|
||||
will load the values.
|
||||
'''
|
||||
self.max_servers = pool_config.get('max-servers', math.inf)
|
||||
self.node_attributes = pool_config.get('node-attributes')
|
||||
|
||||
|
||||
class DriverConfig(ConfigValue):
|
||||
def __init__(self):
|
||||
|
|
|
@ -45,6 +45,20 @@ class KubernetesPool(ConfigPool):
|
|||
def __repr__(self):
|
||||
return "<KubernetesPool %s>" % self.name
|
||||
|
||||
def load(self, pool_config, full_config):
|
||||
super().load(pool_config)
|
||||
self.name = pool_config['name']
|
||||
self.labels = {}
|
||||
for label in pool_config.get('labels', []):
|
||||
pl = KubernetesLabel()
|
||||
pl.name = label['name']
|
||||
pl.type = label['type']
|
||||
pl.image = label.get('image')
|
||||
pl.image_pull = label.get('image-pull', 'IfNotPresent')
|
||||
pl.pool = self
|
||||
self.labels[pl.name] = pl
|
||||
full_config.labels[label['name']].pools.append(self)
|
||||
|
||||
|
||||
class KubernetesProviderConfig(ProviderConfig):
|
||||
def __init__(self, driver, provider):
|
||||
|
@ -72,19 +86,9 @@ class KubernetesProviderConfig(ProviderConfig):
|
|||
self.context = self.provider['context']
|
||||
for pool in self.provider.get('pools', []):
|
||||
pp = KubernetesPool()
|
||||
pp.name = pool['name']
|
||||
pp.load(pool, config)
|
||||
pp.provider = self
|
||||
self.pools[pp.name] = pp
|
||||
pp.labels = {}
|
||||
for label in pool.get('labels', []):
|
||||
pl = KubernetesLabel()
|
||||
pl.name = label['name']
|
||||
pl.type = label['type']
|
||||
pl.image = label.get('image')
|
||||
pl.image_pull = label.get('image-pull', 'IfNotPresent')
|
||||
pl.pool = pp
|
||||
pp.labels[pl.name] = pl
|
||||
config.labels[label['name']].pools.append(pp)
|
||||
|
||||
def getSchema(self):
|
||||
k8s_label = {
|
||||
|
@ -94,10 +98,11 @@ class KubernetesProviderConfig(ProviderConfig):
|
|||
'image-pull': str,
|
||||
}
|
||||
|
||||
pool = {
|
||||
pool = ConfigPool.getCommonSchemaDict()
|
||||
pool.update({
|
||||
v.Required('name'): str,
|
||||
v.Required('labels'): [k8s_label],
|
||||
}
|
||||
})
|
||||
|
||||
provider = {
|
||||
v.Required('pools'): [pool],
|
||||
|
|
|
@ -149,6 +149,64 @@ class ProviderPool(ConfigPool):
|
|||
def __repr__(self):
|
||||
return "<ProviderPool %s>" % self.name
|
||||
|
||||
def load(self, pool_config, full_config, provider):
|
||||
'''
|
||||
Load pool configuration options.
|
||||
|
||||
:param dict pool_config: A single pool config section from which we
|
||||
will load the values.
|
||||
:param dict full_config: The full nodepool config.
|
||||
:param OpenStackProviderConfig: The calling provider object.
|
||||
'''
|
||||
super().load(pool_config)
|
||||
|
||||
self.provider = provider
|
||||
self.name = pool_config['name']
|
||||
self.max_cores = pool_config.get('max-cores', math.inf)
|
||||
self.max_ram = pool_config.get('max-ram', math.inf)
|
||||
self.ignore_provider_quota = pool_config.get('ignore-provider-quota',
|
||||
False)
|
||||
self.azs = pool_config.get('availability-zones')
|
||||
self.networks = pool_config.get('networks', [])
|
||||
self.security_groups = pool_config.get('security-groups', [])
|
||||
self.auto_floating_ip = bool(pool_config.get('auto-floating-ip', True))
|
||||
self.host_key_checking = bool(pool_config.get('host-key-checking',
|
||||
True))
|
||||
|
||||
for label in pool_config.get('labels', []):
|
||||
pl = ProviderLabel()
|
||||
pl.name = label['name']
|
||||
pl.pool = self
|
||||
self.labels[pl.name] = pl
|
||||
diskimage = label.get('diskimage', None)
|
||||
if diskimage:
|
||||
pl.diskimage = full_config.diskimages[diskimage]
|
||||
else:
|
||||
pl.diskimage = None
|
||||
cloud_image_name = label.get('cloud-image', None)
|
||||
if cloud_image_name:
|
||||
cloud_image = provider.cloud_images.get(cloud_image_name, None)
|
||||
if not cloud_image:
|
||||
raise ValueError(
|
||||
"cloud-image %s does not exist in provider %s"
|
||||
" but is referenced in label %s" %
|
||||
(cloud_image_name, self.name, pl.name))
|
||||
else:
|
||||
cloud_image = None
|
||||
pl.cloud_image = cloud_image
|
||||
pl.min_ram = label.get('min-ram', 0)
|
||||
pl.flavor_name = label.get('flavor-name', None)
|
||||
pl.key_name = label.get('key-name')
|
||||
pl.console_log = label.get('console-log', False)
|
||||
pl.boot_from_volume = bool(label.get('boot-from-volume',
|
||||
False))
|
||||
pl.volume_size = label.get('volume-size', 50)
|
||||
pl.instance_properties = label.get('instance-properties',
|
||||
None)
|
||||
|
||||
top_label = full_config.labels[pl.name]
|
||||
top_label.pools.append(self)
|
||||
|
||||
|
||||
class OpenStackProviderConfig(ProviderConfig):
|
||||
def __init__(self, driver, provider):
|
||||
|
@ -263,53 +321,8 @@ class OpenStackProviderConfig(ProviderConfig):
|
|||
|
||||
for pool in self.provider.get('pools', []):
|
||||
pp = ProviderPool()
|
||||
pp.name = pool['name']
|
||||
pp.provider = self
|
||||
pp.load(pool, config, self)
|
||||
self.pools[pp.name] = pp
|
||||
pp.max_cores = pool.get('max-cores', math.inf)
|
||||
pp.max_servers = pool.get('max-servers', math.inf)
|
||||
pp.max_ram = pool.get('max-ram', math.inf)
|
||||
pp.ignore_provider_quota = pool.get('ignore-provider-quota', False)
|
||||
pp.azs = pool.get('availability-zones')
|
||||
pp.networks = pool.get('networks', [])
|
||||
pp.security_groups = pool.get('security-groups', [])
|
||||
pp.auto_floating_ip = bool(pool.get('auto-floating-ip', True))
|
||||
pp.host_key_checking = bool(pool.get('host-key-checking', True))
|
||||
pp.node_attributes = pool.get('node-attributes')
|
||||
|
||||
for label in pool.get('labels', []):
|
||||
pl = ProviderLabel()
|
||||
pl.name = label['name']
|
||||
pl.pool = pp
|
||||
pp.labels[pl.name] = pl
|
||||
diskimage = label.get('diskimage', None)
|
||||
if diskimage:
|
||||
pl.diskimage = config.diskimages[diskimage]
|
||||
else:
|
||||
pl.diskimage = None
|
||||
cloud_image_name = label.get('cloud-image', None)
|
||||
if cloud_image_name:
|
||||
cloud_image = self.cloud_images.get(cloud_image_name, None)
|
||||
if not cloud_image:
|
||||
raise ValueError(
|
||||
"cloud-image %s does not exist in provider %s"
|
||||
" but is referenced in label %s" %
|
||||
(cloud_image_name, self.name, pl.name))
|
||||
else:
|
||||
cloud_image = None
|
||||
pl.cloud_image = cloud_image
|
||||
pl.min_ram = label.get('min-ram', 0)
|
||||
pl.flavor_name = label.get('flavor-name', None)
|
||||
pl.key_name = label.get('key-name')
|
||||
pl.console_log = label.get('console-log', False)
|
||||
pl.boot_from_volume = bool(label.get('boot-from-volume',
|
||||
False))
|
||||
pl.volume_size = label.get('volume-size', 50)
|
||||
pl.instance_properties = label.get('instance-properties',
|
||||
None)
|
||||
|
||||
top_label = config.labels[pl.name]
|
||||
top_label.pools.append(pp)
|
||||
|
||||
def getSchema(self):
|
||||
provider_diskimage = {
|
||||
|
@ -358,20 +371,19 @@ class OpenStackProviderConfig(ProviderConfig):
|
|||
v.Any(label_min_ram, label_flavor_name),
|
||||
v.Any(label_diskimage, label_cloud_image))
|
||||
|
||||
pool = {
|
||||
pool = ConfigPool.getCommonSchemaDict()
|
||||
pool.update({
|
||||
'name': str,
|
||||
'networks': [str],
|
||||
'auto-floating-ip': bool,
|
||||
'host-key-checking': bool,
|
||||
'ignore-provider-quota': bool,
|
||||
'max-cores': int,
|
||||
'max-servers': int,
|
||||
'max-ram': int,
|
||||
'labels': [pool_label],
|
||||
'node-attributes': dict,
|
||||
'availability-zones': [str],
|
||||
'security-groups': [str]
|
||||
}
|
||||
})
|
||||
|
||||
return v.Schema({
|
||||
'region-name': str,
|
||||
|
|
|
@ -41,6 +41,33 @@ class StaticPool(ConfigPool):
|
|||
def __repr__(self):
|
||||
return "<StaticPool %s>" % self.name
|
||||
|
||||
def load(self, pool_config, full_config):
|
||||
super().load(pool_config)
|
||||
self.name = pool_config['name']
|
||||
# WARNING: This intentionally changes the type!
|
||||
self.labels = set()
|
||||
for node in pool_config.get('nodes', []):
|
||||
self.nodes.append({
|
||||
'name': node['name'],
|
||||
'labels': as_list(node['labels']),
|
||||
'host-key': as_list(node.get('host-key', [])),
|
||||
'timeout': int(node.get('timeout', 5)),
|
||||
# Read ssh-port values for backward compat, but prefer port
|
||||
'connection-port': int(
|
||||
node.get('connection-port', node.get('ssh-port', 22))),
|
||||
'connection-type': node.get('connection-type', 'ssh'),
|
||||
'username': node.get('username', 'zuul'),
|
||||
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
|
||||
})
|
||||
if isinstance(node['labels'], str):
|
||||
for label in node['labels'].split():
|
||||
self.labels.add(label)
|
||||
full_config.labels[label].pools.append(self)
|
||||
elif isinstance(node['labels'], list):
|
||||
for label in node['labels']:
|
||||
self.labels.add(label)
|
||||
full_config.labels[label].pools.append(self)
|
||||
|
||||
|
||||
class StaticProviderConfig(ProviderConfig):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -65,32 +92,9 @@ class StaticProviderConfig(ProviderConfig):
|
|||
def load(self, config):
|
||||
for pool in self.provider.get('pools', []):
|
||||
pp = StaticPool()
|
||||
pp.name = pool['name']
|
||||
pp.load(pool, config)
|
||||
pp.provider = self
|
||||
self.pools[pp.name] = pp
|
||||
# WARNING: This intentionally changes the type!
|
||||
pp.labels = set()
|
||||
for node in pool.get('nodes', []):
|
||||
pp.nodes.append({
|
||||
'name': node['name'],
|
||||
'labels': as_list(node['labels']),
|
||||
'host-key': as_list(node.get('host-key', [])),
|
||||
'timeout': int(node.get('timeout', 5)),
|
||||
# Read ssh-port values for backward compat, but prefer port
|
||||
'connection-port': int(
|
||||
node.get('connection-port', node.get('ssh-port', 22))),
|
||||
'connection-type': node.get('connection-type', 'ssh'),
|
||||
'username': node.get('username', 'zuul'),
|
||||
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
|
||||
})
|
||||
if isinstance(node['labels'], str):
|
||||
for label in node['labels'].split():
|
||||
pp.labels.add(label)
|
||||
config.labels[label].pools.append(pp)
|
||||
elif isinstance(node['labels'], list):
|
||||
for label in node['labels']:
|
||||
pp.labels.add(label)
|
||||
config.labels[label].pools.append(pp)
|
||||
|
||||
def getSchema(self):
|
||||
pool_node = {
|
||||
|
@ -103,10 +107,11 @@ class StaticProviderConfig(ProviderConfig):
|
|||
'connection-type': str,
|
||||
'max-parallel-jobs': int,
|
||||
}
|
||||
pool = {
|
||||
pool = ConfigPool.getCommonSchemaDict()
|
||||
pool.update({
|
||||
'name': str,
|
||||
'nodes': [pool_node],
|
||||
}
|
||||
})
|
||||
return v.Schema({'pools': [pool]})
|
||||
|
||||
def getSupportedLabels(self, pool_name=None):
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import math
|
||||
import voluptuous as v
|
||||
|
||||
from nodepool.driver import ConfigPool
|
||||
|
@ -20,7 +19,10 @@ from nodepool.driver import ProviderConfig
|
|||
|
||||
|
||||
class TestPool(ConfigPool):
|
||||
pass
|
||||
def load(self, pool_config):
|
||||
super().load(pool_config)
|
||||
self.name = pool_config['name']
|
||||
self.labels = pool_config['labels']
|
||||
|
||||
|
||||
class TestConfig(ProviderConfig):
|
||||
|
@ -43,18 +45,19 @@ class TestConfig(ProviderConfig):
|
|||
self.labels = set()
|
||||
for pool in self.provider.get('pools', []):
|
||||
testpool = TestPool()
|
||||
testpool.name = pool['name']
|
||||
testpool.load(pool)
|
||||
testpool.provider = self
|
||||
testpool.max_servers = pool.get('max-servers', math.inf)
|
||||
testpool.labels = pool['labels']
|
||||
for label in pool['labels']:
|
||||
self.labels.add(label)
|
||||
newconfig.labels[label].pools.append(testpool)
|
||||
self.pools[pool['name']] = testpool
|
||||
|
||||
def getSchema(self):
|
||||
pool = {'name': str,
|
||||
'labels': [str]}
|
||||
pool = ConfigPool.getCommonSchemaDict()
|
||||
pool.update({
|
||||
'name': str,
|
||||
'labels': [str]
|
||||
})
|
||||
return v.Schema({'pools': [pool]})
|
||||
|
||||
def getSupportedLabels(self, pool_name=None):
|
||||
|
|
|
@ -28,11 +28,17 @@ from nodepool.driver.static.config import StaticPool
|
|||
from nodepool.driver.static.config import StaticProviderConfig
|
||||
|
||||
|
||||
class TempConfigPool(ConfigPool):
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestConfigComparisons(tests.BaseTestCase):
|
||||
|
||||
def test_ConfigPool(self):
|
||||
a = ConfigPool()
|
||||
b = ConfigPool()
|
||||
|
||||
a = TempConfigPool()
|
||||
b = TempConfigPool()
|
||||
self.assertEqual(a, b)
|
||||
a.max_servers = 5
|
||||
self.assertNotEqual(a, b)
|
||||
|
@ -94,9 +100,9 @@ class TestConfigComparisons(tests.BaseTestCase):
|
|||
a.max_servers = 5
|
||||
self.assertNotEqual(a, b)
|
||||
|
||||
c = ConfigPool()
|
||||
c = TempConfigPool()
|
||||
d = ProviderPool()
|
||||
self.assertNotEqual(c, d)
|
||||
self.assertNotEqual(d, c)
|
||||
|
||||
def test_OpenStackProviderConfig(self):
|
||||
provider = {'name': 'foo'}
|
||||
|
@ -114,7 +120,7 @@ class TestConfigComparisons(tests.BaseTestCase):
|
|||
# intentionally change an attribute of the base class
|
||||
a.max_servers = 5
|
||||
self.assertNotEqual(a, b)
|
||||
c = ConfigPool()
|
||||
c = TempConfigPool()
|
||||
self.assertNotEqual(b, c)
|
||||
|
||||
def test_StaticProviderConfig(self):
|
||||
|
|
Loading…
Reference in New Issue