Refactor & add unittests for almost all functionality, add tox for running tests on py27 and py35 (#17)

* Refactor & add unittests for almost all functionality, add tox for running tests on py27 and py35

This commit adds unit tests for current api functionality.
- currently no mocking, all tests are run on the device.
- the library is now compatible with python 2.7 and python 3.5, use tox for tests
- schema checks are done with voluptuous

refactoring:
- protocol is separated into its own file, smartplug adapted to receive protocol worker as parameter.
- cleaned up the initialization routine, initialization is done on use, not on creation of smartplug
- added model and features properties, identity kept for backwards compatibility
- no more storing of local variables outside _sys_info, paves a way to handle state changes sanely (without complete reinitialization)

* Fix CI warnings, remove unused leftover code

* Rename _initialize to _fetch_sysinfo, as that's what it does.

* examples.cli: fix identify call, prettyprint sysinfo, update readme which had false format for led setting

* Add tox-travis for automated testing.
This commit is contained in:
Teemu R 2016-12-16 23:51:56 +01:00 committed by GadgetReactor
parent 45fc354888
commit fd4e363f56
9 changed files with 439 additions and 129 deletions

9
.travis.yml Normal file
View File

@ -0,0 +1,9 @@
sudo: false
language: python
python:
- "2.7"
- "3.4"
- "3.5"
install: pip install tox-travis
script: tox

View File

@ -49,7 +49,7 @@ print("Per month: %s" % plug.get_emeter_monthly(year=2016))
## Switching the led
```python
print("Current LED state: %s" % plug.led)
plug.led = 0 # turn off led
plug.led = False # turn off led
print("New LED state: %s" % plug.led)
```

View File

@ -1,5 +1,6 @@
import sys
import logging
from pprint import pformat as pf
from pyHS100 import SmartPlug
@ -11,8 +12,8 @@ if len(sys.argv) < 2:
hs = SmartPlug(sys.argv[1])
logging.info("Identify: %s", hs.identify)
logging.info("Sysinfo: %s", hs.get_sysinfo())
logging.info("Identify: %s", hs.identify())
logging.info("Sysinfo: %s", pf(hs.get_sysinfo()))
has_emeter = hs.has_emeter
if has_emeter:
logging.info("== Emeter ==")

View File

@ -1 +1,3 @@
from pyHS100.pyHS100 import SmartPlug
from __future__ import absolute_import
from __future__ import unicode_literals
from pyHS100.pyHS100 import SmartPlug, SmartPlugException

100
pyHS100/protocol.py Normal file
View File

@ -0,0 +1,100 @@
from __future__ import absolute_import
from __future__ import unicode_literals
import json
import socket
import logging
_LOGGER = logging.getLogger(__name__)
class TPLinkSmartHomeProtocol:
"""
Implementation of the TP-Link Smart Home Protocol
Encryption/Decryption methods based on the works of
Lubomir Stroetmann and Tobias Esser
https://www.softscheck.com/en/reverse-engineering-tp-link-hs110/
https://github.com/softScheck/tplink-smartplug/
which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
initialization_vector = 171
@staticmethod
def query(host, request, port=9999):
"""
Request information from a TP-Link SmartHome Device and return the
response.
:param str host: ip address of the device
:param int port: port on the device (default: 9999)
:param request: command to send to the device (can be either dict or
json string)
:return:
"""
if isinstance(request, dict):
request = json.dumps(request)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((host, port))
_LOGGER.debug("> (%i) %s", len(request), request)
sock.send(TPLinkSmartHomeProtocol.encrypt(request))
buffer = bytes()
while True:
chunk = sock.recv(4096)
buffer += chunk
if not chunk:
break
sock.shutdown(socket.SHUT_RDWR)
sock.close()
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
_LOGGER.debug("< (%i) %s", len(response), response)
return json.loads(response)
@staticmethod
def encrypt(request):
"""
Encrypt a request for a TP-Link Smart Home Device.
:param request: plaintext request data
:return: ciphertext request
"""
key = TPLinkSmartHomeProtocol.initialization_vector
buffer = bytearray(4) # 4 nullbytes
for char in request:
cipher = key ^ ord(char)
key = cipher
buffer.append(cipher)
return buffer
@staticmethod
def decrypt(ciphertext):
"""
Decrypt a response of a TP-Link Smart Home Device.
:param ciphertext: encrypted response data
:return: plaintext response
"""
key = TPLinkSmartHomeProtocol.initialization_vector
buffer = []
ciphertext = ciphertext.decode('latin-1')
for char in ciphertext:
plain = key ^ ord(char)
key = ord(char)
buffer.append(chr(plain))
plaintext = ''.join(buffer)
return plaintext

View File

@ -14,11 +14,14 @@ You may obtain a copy of the license at
http://www.apache.org/licenses/LICENSE-2.0
"""
from __future__ import absolute_import
from __future__ import unicode_literals
import datetime
import json
import logging
import socket
import sys
from pyHS100.protocol import TPLinkSmartHomeProtocol
_LOGGER = logging.getLogger(__name__)
@ -60,7 +63,7 @@ class SmartPlug:
ALL_FEATURES = (FEATURE_ENERGY_METER, FEATURE_TIMER)
def __init__(self, ip_address):
def __init__(self, ip_address, protocol=TPLinkSmartHomeProtocol):
"""
Create a new SmartPlug instance, identified through its IP address.
@ -69,20 +72,18 @@ class SmartPlug:
"""
socket.inet_pton(socket.AF_INET, ip_address)
self.ip_address = ip_address
self.protocol = protocol
self._sys_info = None
self.initialize()
def initialize(self):
def _fetch_sysinfo(self):
"""
(Re-)Initializes the state.
Fetches the system information from the device.
This should be called when the state of the plug is changed anyway.
This should be called when the state of the plug is changed.
:raises: SmartPlugException: on error
"""
self.sys_info = self.get_sysinfo()
self._alias, self.model, self.features = self.identify()
self._sys_info = self.get_sysinfo()
def _query_helper(self, target, cmd, arg={}):
"""
@ -95,19 +96,31 @@ class SmartPlug:
:rtype: dict
:raises SmartPlugException: if command was not executed correctly
"""
response = TPLinkSmartHomeProtocol.query(
host=self.ip_address,
request={target: {cmd: arg}}
)
result = response[target][cmd]
if result["err_code"] != 0:
try:
response = self.protocol.query(
host=self.ip_address,
request={target: {cmd: arg}}
)
except Exception as ex:
raise SmartPlugException(ex)
result = response[target]
if "err_code" in result and result["err_code"] != 0:
raise SmartPlugException("Error on {}.{}: {}".format(target, cmd, result))
result = result[cmd]
del result["err_code"]
return result
@property
def sys_info(self):
if not self._sys_info:
self._fetch_sysinfo()
return self._sys_info
@property
def state(self):
"""
@ -141,14 +154,16 @@ class SmartPlug:
:raises SmartPlugException: on error
"""
if value.upper() == SmartPlug.SWITCH_STATE_ON:
if not isinstance(value, str):
raise ValueError("State must be str, not of %s.", type(value))
elif value.upper() == SmartPlug.SWITCH_STATE_ON:
self.turn_on()
elif value.upper() == SmartPlug.SWITCH_STATE_OFF:
self.turn_off()
else:
raise ValueError("State %s is not valid.", value)
self.initialize()
self._fetch_sysinfo()
def get_sysinfo(self):
"""
@ -187,7 +202,7 @@ class SmartPlug:
"""
self._query_helper("system", "set_relay_state", {"state": 1})
self.initialize()
self._fetch_sysinfo()
def turn_off(self):
"""
@ -197,7 +212,7 @@ class SmartPlug:
"""
self._query_helper("system", "set_relay_state", {"state": 0})
self.initialize()
self._fetch_sysinfo()
@property
def has_emeter(self):
@ -282,7 +297,7 @@ class SmartPlug:
self._query_helper("emeter", "erase_emeter_stat", None)
self.initialize()
self._fetch_sysinfo()
# As query_helper raises exception in case of failure, we have succeeded when we are this far.
return True
@ -309,16 +324,35 @@ class SmartPlug:
:return: (alias, model, list of supported features)
:rtype: tuple
"""
alias = self.sys_info['alias']
model = self.sys_info['model']
return self.alias, self.model, self.features
@property
def model(self):
"""
Get model of the device
:return: device model
:rtype: str
:raises SmartPlugException: on error
"""
return self.sys_info['model']
@property
def features(self):
"""
Returns features of the devices
:return: list of features
:rtype: list
"""
features = self.sys_info['feature'].split(':')
for feature in features:
if feature not in SmartPlug.ALL_FEATURES:
_LOGGER.warning("Unknown feature %s on device %s.",
feature, model)
feature, self.model)
return alias, model, features
return features
@property
def alias(self):
@ -328,7 +362,7 @@ class SmartPlug:
:return: Device name aka alias.
:rtype: str
"""
return self._alias
return self.sys_info['alias']
@alias.setter
def alias(self, alias):
@ -340,7 +374,7 @@ class SmartPlug:
"""
self._query_helper("system", "set_dev_alias", {"alias": alias})
self.initialize()
self._fetch_sysinfo()
@property
def led(self):
@ -362,7 +396,7 @@ class SmartPlug:
"""
self._query_helper("system", "set_led_off", {"off": int(not state)})
self.initialize()
self._fetch_sysinfo()
@property
def icon(self):
@ -510,102 +544,7 @@ class SmartPlug:
"""
self._query_helper("system", "set_mac_addr", {"mac": mac})
self.initialize()
self._fetch_sysinfo()
class TPLinkSmartHomeProtocol:
"""
Implementation of the TP-Link Smart Home Protocol
Encryption/Decryption methods based on the works of
Lubomir Stroetmann and Tobias Esser
https://www.softscheck.com/en/reverse-engineering-tp-link-hs110/
https://github.com/softScheck/tplink-smartplug/
which are licensed under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
initialization_vector = 171
@staticmethod
def query(host, request, port=9999):
"""
Request information from a TP-Link SmartHome Device and return the
response.
:param str host: ip address of the device
:param int port: port on the device (default: 9999)
:param request: command to send to the device (can be either dict or
json string)
:return:
"""
if isinstance(request, dict):
request = json.dumps(request)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((host, port))
_LOGGER.debug("> (%i) %s", len(request), request)
sock.send(TPLinkSmartHomeProtocol.encrypt(request))
buffer = bytes()
while True:
chunk = sock.recv(4096)
buffer += chunk
if not chunk:
break
sock.shutdown(socket.SHUT_RDWR)
sock.close()
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
_LOGGER.debug("< (%i) %s", len(response), response)
return json.loads(response)
@staticmethod
def encrypt(request):
"""
Encrypt a request for a TP-Link Smart Home Device.
:param request: plaintext request data
:return: ciphertext request
"""
key = TPLinkSmartHomeProtocol.initialization_vector
buffer = ['\0\0\0\0']
for char in request:
cipher = key ^ ord(char)
key = cipher
buffer.append(chr(cipher))
ciphertext = ''.join(buffer)
if sys.version_info.major > 2:
ciphertext = ciphertext.encode('latin-1')
return ciphertext
@staticmethod
def decrypt(ciphertext):
"""
Decrypt a response of a TP-Link Smart Home Device.
:param ciphertext: encrypted response data
:return: plaintext response
"""
key = TPLinkSmartHomeProtocol.initialization_vector
buffer = []
if sys.version_info.major > 2:
ciphertext = ciphertext.decode('latin-1')
for char in ciphertext:
plain = key ^ ord(char)
key = ord(char)
buffer.append(chr(plain))
plaintext = ''.join(buffer)
return plaintext

View File

@ -0,0 +1,15 @@
from __future__ import absolute_import
from __future__ import unicode_literals
from unittest import TestCase
from pyHS100.protocol import TPLinkSmartHomeProtocol
import json
class TestTPLinkSmartHomeProtocol(TestCase):
def test_encrypt(self):
d = json.dumps({'foo': 1, 'bar': 2})
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
# encrypt appends nullbytes for the protocol sends
encrypted = encrypted.lstrip(b'\0')
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted))

View File

@ -0,0 +1,231 @@
from __future__ import absolute_import
from __future__ import unicode_literals
from unittest import TestCase, skip, skipIf
from voluptuous import Schema, Invalid, All, Range
from functools import partial
import datetime
import re
from pyHS100 import SmartPlug, SmartPlugException
PLUG_IP = '192.168.250.186'
SKIP_STATE_TESTS = True
# python2 compatibility
try:
basestring
except NameError:
basestring = str
def check_int_bool(x):
if x != 0 and x != 1:
raise Invalid(x)
return x
def check_mac(x):
if re.match("[0-9a-f]{2}([-:])[0-9a-f]{2}(\\1[0-9a-f]{2}){4}$", x.lower()):
return x
raise Invalid(x)
def check_mode(x):
if x in ['schedule']:
return x
raise Invalid("invalid mode {}".format(x))
class TestSmartPlug(TestCase):
sysinfo_schema = Schema({
'active_mode': check_mode,
'alias': basestring,
'dev_name': basestring,
'deviceId': basestring,
'feature': basestring,
'fwId': basestring,
'hwId': basestring,
'hw_ver': basestring,
'icon_hash': basestring,
'latitude': All(float, Range(min=-90, max=90)),
'led_off': check_int_bool,
'longitude': All(float, Range(min=-180, max=180)),
'mac': check_mac,
'model': basestring,
'oemId': basestring,
'on_time': int,
'relay_state': int,
'rssi': All(int, Range(max=0)),
'sw_ver': basestring,
'type': basestring,
'updating': check_int_bool,
})
current_consumption_schema = Schema({
'voltage': All(float, Range(min=0, max=300)),
'power': All(float, Range(min=0)),
'total': All(float, Range(min=0)),
'current': All(float, Range(min=0)),
})
tz_schema = Schema({
'zone_str': basestring,
'dst_offset': int,
'index': All(int, Range(min=0)),
'tz_str': basestring,
})
def setUp(self):
self.plug = SmartPlug(PLUG_IP)
def tearDown(self):
self.plug = None
def test_initialize(self):
self.assertIsNotNone(self.plug.sys_info)
self.sysinfo_schema(self.plug.sys_info)
def test_initialize_invalid_connection(self):
plug = SmartPlug('127.0.0.1')
with self.assertRaises(SmartPlugException):
plug.sys_info['model']
def test_query_helper(self):
with self.assertRaises(SmartPlugException):
self.plug._query_helper("test", "testcmd", {})
# TODO check for unwrapping?
@skipIf(SKIP_STATE_TESTS, "SKIP_STATE_TESTS is True, skipping")
def test_state(self):
def set_invalid(x):
self.plug.state = x
set_invalid_int = partial(set_invalid, 1234)
self.assertRaises(ValueError, set_invalid_int)
set_invalid_str = partial(set_invalid, "1234")
self.assertRaises(ValueError, set_invalid_str)
set_invalid_bool = partial(set_invalid, True)
self.assertRaises(ValueError, set_invalid_bool)
orig_state = self.plug.state
if orig_state == SmartPlug.SWITCH_STATE_OFF:
self.plug.state = "ON"
self.assertTrue(self.plug.state == SmartPlug.SWITCH_STATE_ON)
self.plug.state = "OFF"
self.assertTrue(self.plug.state == SmartPlug.SWITCH_STATE_OFF)
elif orig_state == SmartPlug.SWITCH_STATE_ON:
self.plug.state = "OFF"
self.assertTrue(self.plug.state == SmartPlug.SWITCH_STATE_OFF)
self.plug.state = "ON"
self.assertTrue(self.plug.state == SmartPlug.SWITCH_STATE_ON)
elif orig_state == SmartPlug.SWITCH_STATE_UNKNOWN:
self.fail("can't test for unknown state")
def test_get_sysinfo(self):
# initialize checks for this already, but just to be sure
self.sysinfo_schema(self.plug.get_sysinfo())
@skipIf(SKIP_STATE_TESTS, "SKIP_STATE_TESTS is True, skipping")
def test_turns_and_isses(self):
orig_state = self.plug.is_on
if orig_state:
self.plug.turn_off()
self.assertFalse(self.plug.is_on)
self.assertTrue(self.plug.is_off)
self.plug.turn_on()
self.assertTrue(self.plug.is_on)
else:
self.plug.turn_on()
self.assertFalse(self.plug.is_off)
self.assertTrue(self.plug.is_on)
self.plug.turn_off()
self.assertTrue(self.plug.is_off)
def test_has_emeter(self):
# a not so nice way for checking for emeter availability..
if "110" in self.plug.sys_info["model"]:
self.assertTrue(self.plug.has_emeter)
else:
self.assertFalse(self.plug.has_emeter)
def test_get_emeter_realtime(self):
self.current_consumption_schema((self.plug.get_emeter_realtime()))
def test_get_emeter_daily(self):
self.assertEqual(self.plug.get_emeter_daily(year=1900, month=1), {})
k, v = self.plug.get_emeter_daily().popitem()
self.assertTrue(isinstance(k, int))
self.assertTrue(isinstance(v, float))
def test_get_emeter_monthly(self):
self.assertEqual(self.plug.get_emeter_monthly(year=1900), {})
d = self.plug.get_emeter_monthly()
k, v = d.popitem()
self.assertTrue(isinstance(k, int))
self.assertTrue(isinstance(v, float))
@skip("not clearing your stats..")
def test_erase_emeter_stats(self):
self.fail()
def test_current_consumption(self):
x = self.plug.current_consumption()
self.assertTrue(isinstance(x, float))
self.assertTrue(x >= 0.0)
def test_identify(self):
ident = self.plug.identify()
self.assertTrue(isinstance(ident, tuple))
self.assertTrue(len(ident) == 3)
def test_alias(self):
test_alias = "TEST1234"
original = self.plug.alias
self.assertTrue(isinstance(original, basestring))
self.plug.alias = test_alias
self.assertEqual(self.plug.alias, test_alias)
self.plug.alias = original
self.assertEqual(self.plug.alias, original)
def test_led(self):
original = self.plug.led
self.plug.led = False
self.assertFalse(self.plug.led)
self.plug.led = True
self.assertTrue(self.plug.led)
self.plug.led = original
def test_icon(self):
self.assertEqual(set(self.plug.icon.keys()), {'icon', 'hash'})
def test_time(self):
self.assertTrue(isinstance(self.plug.time, datetime.datetime))
# TODO check setting?
def test_timezone(self):
self.tz_schema(self.plug.timezone)
def test_hw_info(self):
self.sysinfo_schema(self.plug.hw_info)
def test_on_since(self):
self.assertTrue(isinstance(self.plug.on_since, datetime.datetime))
def test_location(self):
self.sysinfo_schema(self.plug.location)
def test_rssi(self):
self.sysinfo_schema({'rssi': self.plug.rssi}) # wrapping for vol
def test_mac(self):
self.sysinfo_schema({'mac': self.plug.mac}) # wrapping for val
# TODO check setting?

13
tox.ini Normal file
View File

@ -0,0 +1,13 @@
[tox]
envlist=py27,py34,py35
[tox:travis]
2.7 = py27
3.4 = py34
3.5 = py35
[testenv]
deps=
future
pytest
voluptuous
commands=py.test