diff --git a/cterasdk/core/devices.py b/cterasdk/core/devices.py index a45470fe..b71450ef 100644 --- a/cterasdk/core/devices.py +++ b/cterasdk/core/devices.py @@ -11,7 +11,7 @@ class Devices(BaseCommand): name_attr = 'name' type_attr = 'deviceType' - default = ['name', 'portal', 'deviceType', 'version', 'remoteAccessUrl'] + default = ['name', 'portal', 'deviceType', 'version', 'remoteAccessUrl', 'deviceDnsName'] def _create_device_resource_uri(self, device_name, tenant): session = self._core.session() diff --git a/cterasdk/core/remote.py b/cterasdk/core/remote.py index 432f7c35..5c7ff757 100644 --- a/cterasdk/core/remote.py +++ b/cterasdk/core/remote.py @@ -1,11 +1,21 @@ +from urllib.parse import urlparse from .enum import DeviceType from ..objects.synchronous import edge, drive -from ..common import parse_base_object_ref + + +def _relay_base(Portal, device): + device_dns = getattr(device, 'deviceDnsName', None) + if device_dns and device_dns.startswith(f'{device.name}.'): + portal_hostname = device_dns[len(device.name) + 1:] + parsed = urlparse(Portal.ctera.baseurl) + port = f':{parsed.port}' if parsed.port not in (None, 80, 443) else '' + base_path = parsed.path.rstrip('/') + return f'{parsed.scheme}://{portal_hostname}{port}{base_path}/devices/{device.name}' + return f'{Portal.ctera.baseurl.rstrip("/")}/devices/{device.name}' def remote_command(Portal, device): - tenant = parse_base_object_ref(device.portal).name - base = f'{Portal.ctera.baseurl}/devicecmdnew/{tenant}/{device.name}' + base = _relay_base(Portal, device) ManagedDevice = None if device.deviceType in DeviceType.Gateways: diff --git a/cterasdk/edge/__init__.py b/cterasdk/edge/__init__.py index be39ec2b..0b2d9c3c 100644 --- a/cterasdk/edge/__init__.py +++ b/cterasdk/edge/__init__.py @@ -30,6 +30,7 @@ 'shares', 'shell', 'smb', + 'stats', 'support', 'sync', 'syslog', diff --git a/cterasdk/edge/stats.py b/cterasdk/edge/stats.py new file mode 100644 index 00000000..b75e739c --- /dev/null +++ b/cterasdk/edge/stats.py @@ -0,0 +1,20 @@ +import logging + +from .base_command import BaseCommand + + +logger = logging.getLogger('cterasdk.edge') + + +VALID_STAT_TYPES = ('cpu', 'memory', 'cache', 'volume', 'connections', 'local_io', 'disk_io', 'cloud_io') +VALID_INTERVALS = ('hour', 'day', 'week', 'month', 'year', 'last') + + +class Stats(BaseCommand): + + def get(self, stat_type, interval='hour'): + if stat_type not in VALID_STAT_TYPES: + raise ValueError(f'Invalid stat_type {stat_type!r}. Valid: {VALID_STAT_TYPES}') + if interval not in VALID_INTERVALS: + raise ValueError(f'Invalid interval {interval!r}. Valid: {VALID_INTERVALS}') + return self._edge.api.get(f'/stats/{stat_type}', params={'interval': interval}) diff --git a/cterasdk/objects/synchronous/drive.py b/cterasdk/objects/synchronous/drive.py index 2ab59a78..f39adc4d 100644 --- a/cterasdk/objects/synchronous/drive.py +++ b/cterasdk/objects/synchronous/drive.py @@ -4,18 +4,20 @@ from ..endpoints import EndpointBuilder from ...lib.session.edge import Session from ...edge import backup, cli, logs, services, support, sync +from .remote_clients import RemoteClients -class Clients: +class Clients(RemoteClients): def __init__(self, drive, Portal): if Portal: drive._Portal = Portal drive.default.close() drive._ctera_session.start_remote_session(Portal.session()) - self.api = Portal.default.clone(clients.API, EndpointBuilder.new(drive.base), authenticator=lambda *_: True) + api_client = Portal.default.clone(clients.API, EndpointBuilder.new(drive.base, '/admingui/api'), authenticator=lambda *_: True) else: - self.api = drive.default.clone(clients.API, EndpointBuilder.new(drive.base, '/admingui/api')) + api_client = drive.default.clone(clients.API, EndpointBuilder.new(drive.base, '/admingui/api')) + super().__init__(drive, Portal, api_client) class Drive(Management): diff --git a/cterasdk/objects/synchronous/edge.py b/cterasdk/objects/synchronous/edge.py index dbe2c9ba..ad84d960 100644 --- a/cterasdk/objects/synchronous/edge.py +++ b/cterasdk/objects/synchronous/edge.py @@ -5,29 +5,29 @@ from .. import authenticators from ...common import modules from ...lib.session.edge import Session - - +from .remote_clients import RemoteClients from ...edge import ( afp, aio, antivirus, array, audit, backup, cache, cli, config, connection, ctera_migrate, dedup, directoryservice, drive, files, firmware, ftp, groups, licenses, login, logs, mail, network, nfs, ntp, power, remote, rsync, ransom_protect, services, - shares, shell, smb, snmp, ssh, ssl, support, sync, syslog, tasks, telnet, + shares, shell, smb, snmp, ssh, ssl, stats, support, sync, syslog, tasks, telnet, timezone, users, volumes, ) -class Clients: +class Clients(RemoteClients): def __init__(self, edge, Portal): if Portal: edge._Portal = Portal edge.default.close() edge._ctera_session.start_remote_session(Portal.session()) - self.api = Portal.default.clone(clients.API, EndpointBuilder.new(edge.base), authenticator=lambda *_: True) + api_client = Portal.default.clone(clients.API, EndpointBuilder.new(edge.base, '/admingui/api'), authenticator=lambda *_: True) else: self.migrate = edge.default.clone(clients.Migrate, EndpointBuilder.new(edge.base, '/migration/rest/v1')) - self.api = edge.default.clone(clients.API, EndpointBuilder.new(edge.base, '/admingui/api')) + api_client = edge.default.clone(clients.API, EndpointBuilder.new(edge.base, '/admingui/api')) self.io = IO(edge) + super().__init__(edge, Portal, api_client) class IO: @@ -106,6 +106,7 @@ def __init__(self, host=None, port=None, https=True, Portal=None, *, base=None): self.shell = shell.Shell(self) self.smb = smb.SMB(self) self.snmp = snmp.SNMP(self) + self.stats = stats.Stats(self) self.ssh = ssh.SSH(self) self.ssl = modules.initialize(ssl.SSLModule, self) self.support = support.Support(self) @@ -164,5 +165,5 @@ def _omit_fields(self): return super()._omit_fields + ['afp', 'aio', 'array', 'audit', 'antivirus', 'backup', 'cache', 'cli', 'config', 'ctera_migrate', 'dedup', 'directoryservice', 'drive', 'files', 'firmware', 'ftp', 'groups', 'licenses', 'logs', 'mail', 'network', 'nfs', 'ntp', 'power', 'ransom_protect', 'rsync', 'services', 'shares', 'shell', - 'smb', 'snmp', 'ssh', 'ssl', 'support', 'sync', 'syslog', 'tasks', 'telnet', 'timezone', + 'smb', 'snmp', 'ssh', 'ssl', 'stats', 'support', 'sync', 'syslog', 'tasks', 'telnet', 'timezone', 'users', 'volumes'] diff --git a/cterasdk/objects/synchronous/remote_clients.py b/cterasdk/objects/synchronous/remote_clients.py new file mode 100644 index 00000000..744e8e49 --- /dev/null +++ b/cterasdk/objects/synchronous/remote_clients.py @@ -0,0 +1,28 @@ +import logging +from ...common import parse_base_object_ref +from ...exceptions import CTERAException + + +logger = logging.getLogger('cterasdk.remote') + + +class RemoteClients: + + def __init__(self, device, Portal, api_client): + self._device = device + self._Portal = Portal + self._authenticated = False + self._api = api_client + + @property + def api(self): + if self._Portal and not self._authenticated: + tenant = parse_base_object_ref(self._device.portal).name + device_name = self._device.name + logger.debug('Auto-SSO login via relay channel. %s', {'tenant': tenant, 'device': device_name}) + token = self._Portal.api.execute(f'/portals/{tenant}/devices/{device_name}', 'singleSignOn') + if not token: + raise CTERAException('Failed to Retrieve SSO Ticket.') + self._api.get('/ssologin', params={'ticket': token}) + self._authenticated = True + return self._api diff --git a/tests/ut/core/admin/test_relay_base.py b/tests/ut/core/admin/test_relay_base.py new file mode 100644 index 00000000..d1e41ded --- /dev/null +++ b/tests/ut/core/admin/test_relay_base.py @@ -0,0 +1,80 @@ +import unittest +from unittest import mock + +from cterasdk.core.remote import _relay_base + + +class TestRelayBase(unittest.TestCase): + + def _make_portal(self, baseurl): + portal = mock.MagicMock() + portal.ctera.baseurl = baseurl + return portal + + def _make_device(self, name, device_dns_name=None): + device = mock.MagicMock(spec=['name', 'deviceDnsName'] if device_dns_name is not None else ['name']) + device.name = name + if device_dns_name is not None: + device.deviceDnsName = device_dns_name + return device + + def test_fallback_when_device_dns_is_none(self): + portal = self._make_portal('https://portal.ctera.me') + device = self._make_device('vGateway-7192') + result = _relay_base(portal, device) + self.assertEqual(result, 'https://portal.ctera.me/devices/vGateway-7192') + + def test_fallback_when_device_dns_does_not_start_with_name(self): + portal = self._make_portal('https://portal.ctera.me') + device = self._make_device('vGateway-7192', 'other-device.portal.ctera.me') + result = _relay_base(portal, device) + self.assertEqual(result, 'https://portal.ctera.me/devices/vGateway-7192') + + def test_hostname_derivation_from_dns_name(self): + portal = self._make_portal('https://10.0.0.1') + device = self._make_device('vGateway-7192', 'vGateway-7192.portal.ctera.me') + result = _relay_base(portal, device) + self.assertEqual(result, 'https://portal.ctera.me/devices/vGateway-7192') + + def test_non_standard_port_preserved(self): + portal = self._make_portal('https://10.0.0.1:8443') + device = self._make_device('vGateway-7192', 'vGateway-7192.portal.ctera.me') + result = _relay_base(portal, device) + self.assertEqual(result, 'https://portal.ctera.me:8443/devices/vGateway-7192') + + def test_standard_https_port_omitted(self): + portal = self._make_portal('https://10.0.0.1:443') + device = self._make_device('vGateway-7192', 'vGateway-7192.portal.ctera.me') + result = _relay_base(portal, device) + self.assertEqual(result, 'https://portal.ctera.me/devices/vGateway-7192') + + def test_standard_http_port_omitted(self): + portal = self._make_portal('http://10.0.0.1:80') + device = self._make_device('vGateway-7192', 'vGateway-7192.portal.ctera.me') + result = _relay_base(portal, device) + self.assertEqual(result, 'http://portal.ctera.me/devices/vGateway-7192') + + def test_trailing_slash_in_baseurl_no_double_slash(self): + portal = self._make_portal('https://portal.ctera.me/') + device = self._make_device('vGateway-7192', 'vGateway-7192.portal.ctera.me') + result = _relay_base(portal, device) + self.assertNotIn('//', result.split('://')[1]) + self.assertEqual(result, 'https://portal.ctera.me/devices/vGateway-7192') + + def test_baseurl_with_path(self): + portal = self._make_portal('https://10.0.0.1/api/v1') + device = self._make_device('vGateway-7192', 'vGateway-7192.portal.ctera.me') + result = _relay_base(portal, device) + self.assertEqual(result, 'https://portal.ctera.me/api/v1/devices/vGateway-7192') + + def test_substring_device_name_does_not_false_match(self): + portal = self._make_portal('https://portal.ctera.me') + device = self._make_device('gw', 'other-gw.portal.ctera.me') + result = _relay_base(portal, device) + self.assertEqual(result, 'https://portal.ctera.me/devices/gw') + + def test_fallback_strips_trailing_slash(self): + portal = self._make_portal('https://portal.ctera.me/') + device = self._make_device('vGateway-7192') + result = _relay_base(portal, device) + self.assertEqual(result, 'https://portal.ctera.me/devices/vGateway-7192') diff --git a/tests/ut/core/admin/test_remote.py b/tests/ut/core/admin/test_remote.py index 12e05d55..a5d28b4c 100644 --- a/tests/ut/core/admin/test_remote.py +++ b/tests/ut/core/admin/test_remote.py @@ -73,6 +73,33 @@ def _create_device_param(name, portal, device_type, remote_access_url): param.remoteAccessUrl = remote_access_url return param + def _setup_remote_device_with_sso(self): + remote_session = self.patch_call("cterasdk.lib.session.edge.Session.start_remote_session") + remote_session.return_value = munch.Munch({'account': munch.Munch({'name': 'mickey', 'tenant': 'tenant'})}) + get_multi_response = TestCoreRemote._create_device_param(self._device_name, self._device_portal, + 'vGateway', self._device_remote_access_url) + self._init_global_admin(get_multi_response=get_multi_response, execute_response=self._sso_ticket) + self._activate_portal_session() + device = devices.Devices(self._global_admin).device(self._device_name) + device._ctera_clients._api = mock.MagicMock() + return device + + def test_auto_sso_on_first_api_access(self): + device = self._setup_remote_device_with_sso() + _ = device.api + self._global_admin.api.execute.assert_called_once_with( + f'/portals/{self._tenant_name}/devices/{self._device_name}', 'singleSignOn') + device._ctera_clients._api.get.assert_called_once_with('/ssologin', params={'ticket': self._sso_ticket}) + + def test_auto_sso_not_repeated_on_subsequent_api_access(self): + device = self._setup_remote_device_with_sso() + _ = device.api + _ = device.api + _ = device.api + self._global_admin.api.execute.assert_called_once_with( + f'/portals/{self._tenant_name}/devices/{self._device_name}', 'singleSignOn') + device._ctera_clients._api.get.assert_called_once_with('/ssologin', params={'ticket': self._sso_ticket}) + @staticmethod def _create_current_session_object(): session = Object() diff --git a/tests/ut/edge/test_stats.py b/tests/ut/edge/test_stats.py new file mode 100644 index 00000000..73d5f7fa --- /dev/null +++ b/tests/ut/edge/test_stats.py @@ -0,0 +1,37 @@ +from cterasdk.edge import stats +from tests.ut.edge import base_edge + + +class TestEdgeStats(base_edge.BaseEdgeTest): + + def setUp(self): + super().setUp() + self._init_filer() + + def test_get_cpu_default_interval(self): + stats.Stats(self._filer).get('cpu') + self._filer.api.get.assert_called_with('/stats/cpu', params={'interval': 'hour'}) + + def test_get_memory_with_interval(self): + stats.Stats(self._filer).get('memory', interval='day') + self._filer.api.get.assert_called_with('/stats/memory', params={'interval': 'day'}) + + def test_get_all_stat_types(self): + for stat_type in stats.VALID_STAT_TYPES: + self._filer.api.get.reset_mock() + stats.Stats(self._filer).get(stat_type, interval='hour') + self._filer.api.get.assert_called_with(f'/stats/{stat_type}', params={'interval': 'hour'}) + + def test_get_all_intervals(self): + for interval in stats.VALID_INTERVALS: + self._filer.api.get.reset_mock() + stats.Stats(self._filer).get('cpu', interval=interval) + self._filer.api.get.assert_called_with('/stats/cpu', params={'interval': interval}) + + def test_invalid_stat_type_raises_value_error(self): + with self.assertRaises(ValueError): + stats.Stats(self._filer).get('invalid_type') + + def test_invalid_interval_raises_value_error(self): + with self.assertRaises(ValueError): + stats.Stats(self._filer).get('cpu', interval='invalid_interval')