summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJonah BrĂ¼chert <jbb@kaidan.im>2024-04-16 18:02:46 +0200
committerJonah BrĂ¼chert <jbb@kaidan.im>2024-04-16 18:41:35 +0200
commitb604a66731ec21a33bd28ad92114c1c4a8b6755c (patch)
tree8b9f3dc032c2734cb6f6f54334771ee3181b8396
parent15f2aab7f1053a46d6d87a7c7f193767516c16cc (diff)
downloadtools-b604a66731ec21a33bd28ad92114c1c4a8b6755c.tar.gz
tools-b604a66731ec21a33bd28ad92114c1c4a8b6755c.tar.bz2
tools-b604a66731ec21a33bd28ad92114c1c4a8b6755c.zip
Add type hints
-rw-r--r--.gitignore1
-rwxr-xr-xbin/hostinfo31
-rw-r--r--hostinfo/prefix.py9
-rw-r--r--hostinfo/printer.py68
4 files changed, 57 insertions, 52 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..751553b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*.bak
diff --git a/bin/hostinfo b/bin/hostinfo
index a97a8d8..b85ccc8 100755
--- a/bin/hostinfo
+++ b/bin/hostinfo
@@ -8,6 +8,8 @@ import os
import pkg_resources
from dns import resolver, reversename
+from typing import Optional
+
OWN_DIRECTORY = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
LIB = os.path.join(OWN_DIRECTORY, '..')
if os.path.exists(os.path.join(LIB, 'hostinfo')):
@@ -16,10 +18,12 @@ if os.path.exists(os.path.join(LIB, 'hostinfo')):
from hostinfo import printer
from hostinfo import utils
+
def _get_data(path):
stream = file(path, 'r')
return yaml.load(stream)
+
def _match_key(data, keys):
if data is None:
return None
@@ -53,6 +57,7 @@ def _match_key(data, keys):
return None
+
def _match(host, search_key, search_value, negate):
search_keys = search_key.split('.')
@@ -71,6 +76,7 @@ def _match(host, search_key, search_value, negate):
return (search_key, result)
return (search_key, None)
+
def _parse_search(search):
if search[0] != '?':
sys.stderr.write("Invalid search string.")
@@ -90,10 +96,13 @@ def _parse_search(search):
return (search, value, negate)
-def print_search(basepath, flags, search, filter_key=None):
+
+def print_search(basepath: str, flags: argparse.Namespace,
+ search: str, filter_key: Optional[str] = None):
def _get_label(host):
if flags.short:
- return host.replace('.spline.inf.fu-berlin.de','')
+ return host.replace('.spline.inf.fu-berlin.de',
+ '')
return host
metadata = os.path.join(basepath, 'metadata', 'hosts')
@@ -124,11 +133,13 @@ def print_search(basepath, flags, search, filter_key=None):
p.info(key, label=_get_label(host),
maxlength=max(length), force=True)
+
def print_info(path, flags, key=None):
data = _get_data(path)
p = printer.Printer(data, flags)
p.info(key)
+
def print_keys(path):
def _print_keys(data, prefix = ''):
if isinstance(data, str):
@@ -152,10 +163,11 @@ def print_keys(path):
data = _get_data(path)
_print_keys(data)
-def print_hosts(path, short):
+
+def print_hosts(path: str, short: bool):
metadata = os.path.join(path, 'metadata', 'hosts')
if os.path.exists(metadata):
- hosts = yaml.load(file(metadata, 'r'))
+ hosts = yaml.safe_load(open(metadata, 'r'))
if 'hosts' in hosts:
for host in hosts['hosts']:
if short:
@@ -167,7 +179,8 @@ def print_hosts(path, short):
sys.stderr.write("'%s' not found!\n" % metadata)
return False
-def find_host(basepath, host):
+
+def find_host(basepath: str, host: str):
path = os.path.join(basepath, host)
if os.path.exists(path):
return path
@@ -191,6 +204,7 @@ def find_host(basepath, host):
return None
+
def print_version_and_exit():
ver = None
try:
@@ -205,7 +219,8 @@ def print_version_and_exit():
print(("hostinfo-tools %s" % ver))
sys.exit(0)
-def main():
+
+def main() -> None:
basepath = '/usr/local/share/hostinfo'
if 'HOSTINFO_PATH' in os.environ and os.environ['HOSTINFO_PATH'] != '':
basepath = os.environ['HOSTINFO_PATH']
@@ -256,7 +271,8 @@ def main():
if args.name.startswith('?'):
# search
- print_search(basepath, search=args.name, filter_key=args.filter, flags=args)
+ print_search(basepath, search=args.name, filter_key=args.filter,
+ flags=args)
else:
# info
path = find_host(basepath, args.name)
@@ -273,5 +289,6 @@ def main():
sys.exit(0)
+
if __name__ == '__main__':
main()
diff --git a/hostinfo/prefix.py b/hostinfo/prefix.py
index 8271d67..7233c48 100644
--- a/hostinfo/prefix.py
+++ b/hostinfo/prefix.py
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
+class Flags:
+ oneline: bool = False
+ nospaces: bool = False
class Printer:
- flags = list()
+ flags = Flags()
def __init__(self, full_key='', printer=None):
if printer is None:
@@ -20,7 +23,7 @@ class Printer:
def set_label(self, label='', maxlength=0):
self.label = self._get_label(label, maxlength)
- def _get_label(self, label, maxlength):
+ def _get_label(self, label: str, maxlength: int):
if label == '':
return label
@@ -33,7 +36,7 @@ class Printer:
else:
return label.ljust(maxlength+2)
- def pprint(self, data):
+ def pprint(self, data: str):
self.output("%s%s" % (self.label, data))
self.has_output = True
if not self.empty:
diff --git a/hostinfo/printer.py b/hostinfo/printer.py
index e467afe..16c0249 100644
--- a/hostinfo/printer.py
+++ b/hostinfo/printer.py
@@ -4,6 +4,8 @@
from hostinfo import prefix
from hostinfo import utils
+from typing import Iterable, Optional
+
def _get_full_key(prev_key, key):
if prev_key == '':
@@ -11,7 +13,7 @@ def _get_full_key(prev_key, key):
return "%s.%s" % (prev_key, key)
-def _sort_with_list(iterable, sort):
+def _sort_with_list(iterable: Iterable, sort: list[str]):
def helper(value):
if sort is None:
return value
@@ -24,7 +26,8 @@ def _sort_with_list(iterable, sort):
return sorted(iterable, key=helper)
-def _space(filter_key, full_key, printer, force=False):
+def _space(filter_key: Optional[str], full_key: str,
+ printer: prefix.Printer, force=False):
if filter_key is None and full_key == '':
printer.space(force)
@@ -49,45 +52,19 @@ class Printer:
self.flags = flags
prefix.Printer.flags = flags
- def cb_print_addresses(self, value, full_key, filter_key):
- def _print_ip(address):
- return '%s/%s' % (address['address'], address['netmask'])
-
- display_check = self._is_group_displayed(full_key, filter_key)
- return utils.group_by(value, 'interface', None, display_check, _print_ip)
-
- def cb_print_ports(self, value, full_key, filter_key):
- def _print_port(port):
- if port['proto'] in ['tcp6', 'udp6']:
- return '(%s) [%s]:%s' % (port['proto'].replace('6', ''),
- port['ip'], port['port'])
- return '(%s) %s:%s' % (port['proto'], port['ip'], port['port'])
-
- display_check = self._is_group_displayed(full_key, filter_key)
- return (utils.group_by(value, 'process', 'UNKNOWN', display_check, _print_port),
- ['sshd', 'nrpe', 'munin-node'])
-
- def cb_print_vserver(self, value, full_key, filter_key):
- if value == 'guest' and 'vserver_host' in self.data and \
- self.data['vserver_host'] is not None:
- return 'guest running on %s' % self.data['vserver_host']
- else:
- return value
-
- def _should_display(self, full_key, filter_key=None):
+ def _should_display(self, full_key,
+ filter_key: Optional[str] = None):
if full_key not in self.ignore:
return True
if filter_key is not None and filter_key.startswith(full_key):
return True
return False
- def _is_group_displayed(self, prev_key, filter_key):
- return (lambda key: self._should_display(_get_full_key(prev_key, key),
- filter_key))
-
- def _print(self, value, printer, filter_key=None, sort=None, force=False):
+ def _print(self, valuestr: str, printer: prefix.Printer,
+ filter_key: Optional[str] = None,
+ sort=None, force=False):
try:
- value = value.strip().splitlines()
+ value = valuestr.strip().splitlines()
except AttributeError:
pass
@@ -101,7 +78,8 @@ class Printer:
else:
self._print_value(value, printer, filter_key)
- def _print_key(self, key, value, printer, filter_key):
+ def _print_key(self, key: str, value: str,
+ printer: prefix.Printer, filter_key):
sort = None
try:
method = getattr(self, 'cb_print_%s' % key.replace('.', '_'))
@@ -116,20 +94,24 @@ class Printer:
self._print(value, printer, filter_key, sort)
- def _print_value(self, value, printer, filter_key):
+ def _print_value(self, value: str, printer: prefix.Printer,
+ filter_key: Optional[str]):
full_key = _get_full_key(printer.full_key, value)
if self._should_display(full_key, filter_key) and \
filter_key is None or full_key == filter_key:
printer.pprint(value)
- def _print_list(self, values, printer, filter_key):
+ def _print_list(self, values: list[str],
+ printer: prefix.Printer,
+ filter_key: Optional[str]):
for value in values:
if isinstance(value, str):
self._print_value(value, printer, filter_key)
else:
self._print(value, printer, filter_key)
- def _print_dict(self, value, printer, filter_key, sort, force):
+ def _print_dict(self, value: dict[str, str], printer: prefix.Printer,
+ filter_key: Optional[str], sort: list, force: bool):
keys = _sort_with_list(
[(key, full_key) for key in list(value.keys())
for full_key in [_get_full_key(printer.full_key, key)]
@@ -138,7 +120,7 @@ class Printer:
if len(keys) == 0:
return
- maxlength = max([len(self._get_label(key, full_key)) \
+ maxlength = max([len(self._get_label(key, full_key))
for (key, full_key) in keys])
_space(filter_key, printer.full_key, printer, True)
@@ -155,7 +137,8 @@ class Printer:
found = True
new_printer = prefix.Printer(full_key, printer)
if filter_key is None:
- new_printer.set_label(self._get_label(key, full_key), maxlength)
+ new_printer.set_label(self._get_label(key, full_key),
+ maxlength)
self._print_key(full_key, value[key], new_printer, new_filter_key)
_space(filter_key, printer.full_key, new_printer)
@@ -163,12 +146,13 @@ class Printer:
if force and not found:
printer.pprint('')
- def _get_label(self, key, full_key):
+ def _get_label(self, key: str, full_key: str):
if full_key in self.labels:
return self.labels[full_key]
return key
- def info(self, key, label=None, maxlength=0, force=False):
+ def info(self, key: Optional[str], label: Optional[str] = None,
+ maxlength=0, force=False):
printer = prefix.Printer()
if label is not None:
printer.set_label(label, maxlength)