Extract out common config parsing for ConfigPool

Our driver code is in a less-than-ideal situation where each driver
is responsible for parsing config options that are common to all
drivers. This change begins to correct that, starting with ConfigPool.
It changes the driver API in the following ways:

1) Forces objects derived from ConfigPool to implement a load() method
   that should call super's method, then handle loading driver specific
   options from the config.

2) Adds a ConfigPool class method that can be called to get the config
   schema for the common config options leaving drivers to have to only
   define the schema for their own config options.

Other base config objects will be modeled after this pattern in
later changes.

Change-Id: I41620590c355cacd2c4fbe6916acfe80f20e3216
This commit is contained in:
David Shrewsbury 2018-12-03 12:10:28 -05:00
parent f116826d2b
commit a19dffd916
6 changed files with 171 additions and 102 deletions

View File

@ -824,7 +824,11 @@ class ConfigValue(object, metaclass=abc.ABCMeta):
return not self.__eq__(other)
class ConfigPool(ConfigValue):
class ConfigPool(ConfigValue, metaclass=abc.ABCMeta):
'''
Base class for a single pool as defined in the configuration file.
'''
def __init__(self):
self.labels = {}
self.max_servers = math.inf
@ -837,6 +841,40 @@ class ConfigPool(ConfigValue):
self.node_attributes == other.node_attributes)
return False
@classmethod
def getCommonSchemaDict(self):
'''
Return the schema dict for common pool attributes.
When a driver validates its own configuration schema, it should call
this class method to get and include the common pool attributes in
the schema.
The `labels` attribute, though common, can vary its type across
drivers so it is not returned in the schema.
'''
return {
'max-servers': int,
'node-attributes': dict,
}
@abc.abstractmethod
def load(self, pool_config):
'''
Load pool config options from the parsed configuration file.
Subclasses are expected to call the parent method so that common
configuration values are loaded properly.
Although `labels` is a common attribute, each driver may
define it differently, so we cannot parse that attribute here.
:param dict pool_config: A single pool config section from which we
will load the values.
'''
self.max_servers = pool_config.get('max-servers', math.inf)
self.node_attributes = pool_config.get('node-attributes')
class DriverConfig(ConfigValue):
def __init__(self):

View File

@ -45,6 +45,20 @@ class KubernetesPool(ConfigPool):
def __repr__(self):
return "<KubernetesPool %s>" % self.name
def load(self, pool_config, full_config):
super().load(pool_config)
self.name = pool_config['name']
self.labels = {}
for label in pool_config.get('labels', []):
pl = KubernetesLabel()
pl.name = label['name']
pl.type = label['type']
pl.image = label.get('image')
pl.image_pull = label.get('image-pull', 'IfNotPresent')
pl.pool = self
self.labels[pl.name] = pl
full_config.labels[label['name']].pools.append(self)
class KubernetesProviderConfig(ProviderConfig):
def __init__(self, driver, provider):
@ -72,19 +86,9 @@ class KubernetesProviderConfig(ProviderConfig):
self.context = self.provider['context']
for pool in self.provider.get('pools', []):
pp = KubernetesPool()
pp.name = pool['name']
pp.load(pool, config)
pp.provider = self
self.pools[pp.name] = pp
pp.labels = {}
for label in pool.get('labels', []):
pl = KubernetesLabel()
pl.name = label['name']
pl.type = label['type']
pl.image = label.get('image')
pl.image_pull = label.get('image-pull', 'IfNotPresent')
pl.pool = pp
pp.labels[pl.name] = pl
config.labels[label['name']].pools.append(pp)
def getSchema(self):
k8s_label = {
@ -94,10 +98,11 @@ class KubernetesProviderConfig(ProviderConfig):
'image-pull': str,
}
pool = {
pool = ConfigPool.getCommonSchemaDict()
pool.update({
v.Required('name'): str,
v.Required('labels'): [k8s_label],
}
})
provider = {
v.Required('pools'): [pool],

View File

@ -149,6 +149,64 @@ class ProviderPool(ConfigPool):
def __repr__(self):
return "<ProviderPool %s>" % self.name
def load(self, pool_config, full_config, provider):
'''
Load pool configuration options.
:param dict pool_config: A single pool config section from which we
will load the values.
:param dict full_config: The full nodepool config.
:param OpenStackProviderConfig: The calling provider object.
'''
super().load(pool_config)
self.provider = provider
self.name = pool_config['name']
self.max_cores = pool_config.get('max-cores', math.inf)
self.max_ram = pool_config.get('max-ram', math.inf)
self.ignore_provider_quota = pool_config.get('ignore-provider-quota',
False)
self.azs = pool_config.get('availability-zones')
self.networks = pool_config.get('networks', [])
self.security_groups = pool_config.get('security-groups', [])
self.auto_floating_ip = bool(pool_config.get('auto-floating-ip', True))
self.host_key_checking = bool(pool_config.get('host-key-checking',
True))
for label in pool_config.get('labels', []):
pl = ProviderLabel()
pl.name = label['name']
pl.pool = self
self.labels[pl.name] = pl
diskimage = label.get('diskimage', None)
if diskimage:
pl.diskimage = full_config.diskimages[diskimage]
else:
pl.diskimage = None
cloud_image_name = label.get('cloud-image', None)
if cloud_image_name:
cloud_image = provider.cloud_images.get(cloud_image_name, None)
if not cloud_image:
raise ValueError(
"cloud-image %s does not exist in provider %s"
" but is referenced in label %s" %
(cloud_image_name, self.name, pl.name))
else:
cloud_image = None
pl.cloud_image = cloud_image
pl.min_ram = label.get('min-ram', 0)
pl.flavor_name = label.get('flavor-name', None)
pl.key_name = label.get('key-name')
pl.console_log = label.get('console-log', False)
pl.boot_from_volume = bool(label.get('boot-from-volume',
False))
pl.volume_size = label.get('volume-size', 50)
pl.instance_properties = label.get('instance-properties',
None)
top_label = full_config.labels[pl.name]
top_label.pools.append(self)
class OpenStackProviderConfig(ProviderConfig):
def __init__(self, driver, provider):
@ -263,53 +321,8 @@ class OpenStackProviderConfig(ProviderConfig):
for pool in self.provider.get('pools', []):
pp = ProviderPool()
pp.name = pool['name']
pp.provider = self
pp.load(pool, config, self)
self.pools[pp.name] = pp
pp.max_cores = pool.get('max-cores', math.inf)
pp.max_servers = pool.get('max-servers', math.inf)
pp.max_ram = pool.get('max-ram', math.inf)
pp.ignore_provider_quota = pool.get('ignore-provider-quota', False)
pp.azs = pool.get('availability-zones')
pp.networks = pool.get('networks', [])
pp.security_groups = pool.get('security-groups', [])
pp.auto_floating_ip = bool(pool.get('auto-floating-ip', True))
pp.host_key_checking = bool(pool.get('host-key-checking', True))
pp.node_attributes = pool.get('node-attributes')
for label in pool.get('labels', []):
pl = ProviderLabel()
pl.name = label['name']
pl.pool = pp
pp.labels[pl.name] = pl
diskimage = label.get('diskimage', None)
if diskimage:
pl.diskimage = config.diskimages[diskimage]
else:
pl.diskimage = None
cloud_image_name = label.get('cloud-image', None)
if cloud_image_name:
cloud_image = self.cloud_images.get(cloud_image_name, None)
if not cloud_image:
raise ValueError(
"cloud-image %s does not exist in provider %s"
" but is referenced in label %s" %
(cloud_image_name, self.name, pl.name))
else:
cloud_image = None
pl.cloud_image = cloud_image
pl.min_ram = label.get('min-ram', 0)
pl.flavor_name = label.get('flavor-name', None)
pl.key_name = label.get('key-name')
pl.console_log = label.get('console-log', False)
pl.boot_from_volume = bool(label.get('boot-from-volume',
False))
pl.volume_size = label.get('volume-size', 50)
pl.instance_properties = label.get('instance-properties',
None)
top_label = config.labels[pl.name]
top_label.pools.append(pp)
def getSchema(self):
provider_diskimage = {
@ -358,20 +371,19 @@ class OpenStackProviderConfig(ProviderConfig):
v.Any(label_min_ram, label_flavor_name),
v.Any(label_diskimage, label_cloud_image))
pool = {
pool = ConfigPool.getCommonSchemaDict()
pool.update({
'name': str,
'networks': [str],
'auto-floating-ip': bool,
'host-key-checking': bool,
'ignore-provider-quota': bool,
'max-cores': int,
'max-servers': int,
'max-ram': int,
'labels': [pool_label],
'node-attributes': dict,
'availability-zones': [str],
'security-groups': [str]
}
})
return v.Schema({
'region-name': str,

View File

@ -41,6 +41,33 @@ class StaticPool(ConfigPool):
def __repr__(self):
return "<StaticPool %s>" % self.name
def load(self, pool_config, full_config):
super().load(pool_config)
self.name = pool_config['name']
# WARNING: This intentionally changes the type!
self.labels = set()
for node in pool_config.get('nodes', []):
self.nodes.append({
'name': node['name'],
'labels': as_list(node['labels']),
'host-key': as_list(node.get('host-key', [])),
'timeout': int(node.get('timeout', 5)),
# Read ssh-port values for backward compat, but prefer port
'connection-port': int(
node.get('connection-port', node.get('ssh-port', 22))),
'connection-type': node.get('connection-type', 'ssh'),
'username': node.get('username', 'zuul'),
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
})
if isinstance(node['labels'], str):
for label in node['labels'].split():
self.labels.add(label)
full_config.labels[label].pools.append(self)
elif isinstance(node['labels'], list):
for label in node['labels']:
self.labels.add(label)
full_config.labels[label].pools.append(self)
class StaticProviderConfig(ProviderConfig):
def __init__(self, *args, **kwargs):
@ -65,32 +92,9 @@ class StaticProviderConfig(ProviderConfig):
def load(self, config):
for pool in self.provider.get('pools', []):
pp = StaticPool()
pp.name = pool['name']
pp.load(pool, config)
pp.provider = self
self.pools[pp.name] = pp
# WARNING: This intentionally changes the type!
pp.labels = set()
for node in pool.get('nodes', []):
pp.nodes.append({
'name': node['name'],
'labels': as_list(node['labels']),
'host-key': as_list(node.get('host-key', [])),
'timeout': int(node.get('timeout', 5)),
# Read ssh-port values for backward compat, but prefer port
'connection-port': int(
node.get('connection-port', node.get('ssh-port', 22))),
'connection-type': node.get('connection-type', 'ssh'),
'username': node.get('username', 'zuul'),
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
})
if isinstance(node['labels'], str):
for label in node['labels'].split():
pp.labels.add(label)
config.labels[label].pools.append(pp)
elif isinstance(node['labels'], list):
for label in node['labels']:
pp.labels.add(label)
config.labels[label].pools.append(pp)
def getSchema(self):
pool_node = {
@ -103,10 +107,11 @@ class StaticProviderConfig(ProviderConfig):
'connection-type': str,
'max-parallel-jobs': int,
}
pool = {
pool = ConfigPool.getCommonSchemaDict()
pool.update({
'name': str,
'nodes': [pool_node],
}
})
return v.Schema({'pools': [pool]})
def getSupportedLabels(self, pool_name=None):

View File

@ -12,7 +12,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import math
import voluptuous as v
from nodepool.driver import ConfigPool
@ -20,7 +19,10 @@ from nodepool.driver import ProviderConfig
class TestPool(ConfigPool):
pass
def load(self, pool_config):
super().load(pool_config)
self.name = pool_config['name']
self.labels = pool_config['labels']
class TestConfig(ProviderConfig):
@ -43,18 +45,19 @@ class TestConfig(ProviderConfig):
self.labels = set()
for pool in self.provider.get('pools', []):
testpool = TestPool()
testpool.name = pool['name']
testpool.load(pool)
testpool.provider = self
testpool.max_servers = pool.get('max-servers', math.inf)
testpool.labels = pool['labels']
for label in pool['labels']:
self.labels.add(label)
newconfig.labels[label].pools.append(testpool)
self.pools[pool['name']] = testpool
def getSchema(self):
pool = {'name': str,
'labels': [str]}
pool = ConfigPool.getCommonSchemaDict()
pool.update({
'name': str,
'labels': [str]
})
return v.Schema({'pools': [pool]})
def getSupportedLabels(self, pool_name=None):

View File

@ -28,11 +28,17 @@ from nodepool.driver.static.config import StaticPool
from nodepool.driver.static.config import StaticProviderConfig
class TempConfigPool(ConfigPool):
def load(self):
pass
class TestConfigComparisons(tests.BaseTestCase):
def test_ConfigPool(self):
a = ConfigPool()
b = ConfigPool()
a = TempConfigPool()
b = TempConfigPool()
self.assertEqual(a, b)
a.max_servers = 5
self.assertNotEqual(a, b)
@ -94,9 +100,9 @@ class TestConfigComparisons(tests.BaseTestCase):
a.max_servers = 5
self.assertNotEqual(a, b)
c = ConfigPool()
c = TempConfigPool()
d = ProviderPool()
self.assertNotEqual(c, d)
self.assertNotEqual(d, c)
def test_OpenStackProviderConfig(self):
provider = {'name': 'foo'}
@ -114,7 +120,7 @@ class TestConfigComparisons(tests.BaseTestCase):
# intentionally change an attribute of the base class
a.max_servers = 5
self.assertNotEqual(a, b)
c = ConfigPool()
c = TempConfigPool()
self.assertNotEqual(b, c)
def test_StaticProviderConfig(self):