Device Tracking
Now that we have uploaded our own keys to the server, we need to get the keys
for other devices in order to encrypt messages for them. When we encrypt an
event for a room, we must know the public keys for all the devices that are in
the room so that they can decrypt the event. We know what users are in the
room, because we can track the m.room.member state
events. From there, we can call the POST /keys/query endpoint to
get users’ devices and the devices’ public keys.
Every time we send an encrypted event to a room, we need to ensure that our
device list is up to date so that new devices will be able to decrypt the
event, and so that we won’t try to send keys to old devices. It would be
inefficient to have to query the device keys of every user in the room, every
time we send an event. In rooms with many users, this could take a long time.
Instead, the server will tell us, via the device_lists property of the GET /sync response, when a user’s devices have changed. We can then keep
track of the users who have device changes and only need to query the users who
have had changes.
The device_lists property of the GET /sync response is an object with two
properties:
changed, which lists users who have updated their device keys (including adding or removing devices), or who now share an encrypted room with the user since the last call to GET /sync; and
left, which lists the users who no longer share an encrypted room with the user since the last call to GET /sync.
"device_lists": schema.Optional(
    {
        "changed": schema.Optional(list[str]),
        "left": schema.Optional(list[str]),
    }
),
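As a small illustration of the shape described above, the following sketch pulls the two lists out of a trimmed-down sync response. The response body, token and user IDs are made up for the example; only the device_lists, changed and left property names come from the text above.

```python
import json

# A trimmed-down /sync response body; the token and user IDs are made up.
raw = """
{
  "next_batch": "token1",
  "device_lists": {
    "changed": ["@bob:example.org", "@carol:example.org"]
  }
}
"""
body = json.loads(raw)

# Both the property itself and each of its lists may be absent, so we fall
# back to empty values rather than assuming that they are present.
device_lists = body.get("device_lists", {})
changed = device_lists.get("changed", [])
left = device_lists.get("left", [])

print("changed:", changed)
print("left:", left)
```

Since every part is optional, a response without a left list is treated the same as one with an empty list.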
When the sync processing loop of the Client class encounters this in the
sync, we will publish a message that contains the contents of the
device_lists property.
if "device_lists" in body:
    await self.publisher.publish(
        DeviceChanges(
            body["device_lists"].get("changed", []),
            body["device_lists"].get("left", []),
        )
    )
class DeviceChanges(typing.NamedTuple):
    """A message indicating that users' devices have changed"""

    changed: list[str]
    left: list[str]


DeviceChanges.changed.__doc__ = "Users whose devices have changed"
DeviceChanges.left.__doc__ = "Users who no longer share a room"
We will now create a DeviceTracker class that we can query for users’ device
keys so that we can encrypt for their devices. It will use the client’s
storage to cache devices that we’ve previously queried, and it will subscribe
to the DeviceChanges messages to know what users it needs to re-fetch devices
for.
For this class, we will start with a simple version, and then extend it to handle some edge cases. We will also handle cross-signing, which we will explain later on (FIXME: link). Thus the code chunks presented in this section may include references to chunks defined in other sections. These can be safely ignored for now.
class DeviceTracker:
    """Tracks user devices to reduce the number of queries required"""

    {{DeviceTracker class methods}}

    def __init__(self, c: client.Client):
        """
        Arguments:

        ``c``:
          the client object
        """
        self.client = c
        c.publisher.subscribe(client.DeviceChanges, self._subscriber)
        {{DeviceTracker initialization}}
We will create a lock (also called a mutex) to ensure that concurrent calls to
the DeviceTracker methods do not conflict in accessing the shared data.
self.lock = asyncio.Lock()
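To see why the lock matters, here is a minimal, self-contained sketch that is independent of the DeviceTracker class. The shared counter and the increment function are invented for the illustration: each task reads a shared value, yields to the event loop, and then writes the value back. Holding the lock across the read and the write prevents the tasks from overwriting each other's updates.

```python
import asyncio


async def main() -> int:
    lock = asyncio.Lock()
    shared = {"count": 0}

    async def increment() -> None:
        async with lock:
            current = shared["count"]
            # Yield to the event loop, as a storage or network call might.
            await asyncio.sleep(0)
            shared["count"] = current + 1

    # Run ten increments concurrently; the lock serializes them.
    await asyncio.gather(*(increment() for _ in range(10)))
    return shared["count"]


final_count = asyncio.run(main())
print(final_count)
```

Without the async with lock line, several tasks could read the same value before any of them writes, and some updates would be lost.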
We will keep track of the users that we are getting updates for (we add users
when they show up in the changed list, and remove them when they show up in
the left list). This will help us determine whether to cache results: if we
are not getting updates for a user, then we should not cache the results for
that user. Since we won’t be notified when their devices change, we should err
on the safe side and re-fetch their keys the next time we need them. As well,
when a user is listed in changed or left, we will clear their cached data
so that we will re-fetch their device keys the next time we need them.
async def _subscriber(self, changes: client.DeviceChanges) -> None:
    async with self.lock:
        tracked_users = self.client.storage.get(
            "device_tracker.tracked_users",
            {self.client.user_id: True},  # We always track our own devices
        )
        for user in changes.changed:
            tracked_users[user] = True
            self._delete_user_device_keys(user)
        for user in changes.left:
            if user in tracked_users:
                del tracked_users[user]
            self._delete_user_device_keys(user)
        self.client.storage["device_tracker.tracked_users"] = tracked_users


def _delete_user_device_keys(self, user):
    user_key = f"device_tracker.cache.{user}"
    if user_key in self.client.storage:
        del self.client.storage[user_key]
    {{mark in-flight device key requests}}
(The “mark in-flight device key requests” chunk is for dealing with some edge cases and will be discussed below. It can be ignored for now.)
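The bookkeeping above can be tried out in isolation. The following sketch uses a plain dict in place of the client's storage and a free function in place of the subscriber method; the function name apply_device_changes is invented for this example, and the locking and in-flight handling are left out.

```python
def apply_device_changes(
    storage: dict, own_user_id: str, changed: list[str], left: list[str]
) -> None:
    """Update the tracked users and drop stale cache entries."""
    # We always track our own devices, so seed the mapping with our own ID.
    tracked = storage.get("device_tracker.tracked_users", {own_user_id: True})
    for user in changed:
        tracked[user] = True
        storage.pop(f"device_tracker.cache.{user}", None)
    for user in left:
        tracked.pop(user, None)
        storage.pop(f"device_tracker.cache.{user}", None)
    storage["device_tracker.tracked_users"] = tracked


storage = {
    "device_tracker.cache.@bob:example.org": {"device_keys": {}},
    "device_tracker.cache.@carol:example.org": {"device_keys": {}},
}
apply_device_changes(
    storage,
    "@alice:example.org",
    changed=["@bob:example.org"],
    left=["@carol:example.org"],
)
print(storage)
```

After the call, Bob is tracked but his cached keys are gone, so his keys will be re-fetched the next time they are needed. Carol is neither tracked nor cached.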
Now we create a function to get the device keys for users. It will take an argument to force re-downloading all the keys, ignoring the cache. It will also take a timeout parameter which will limit the time that the homeserver will wait for a response from remote servers.
Tradeoff
Rather than querying the device keys for a user when we need them, we could query the device keys immediately when we are notified that the user’s devices have changed. This would allow us to return the users’ devices immediately, rather than having to wait for a query to the server to complete. However, this may result in unnecessary requests to the server, especially if a user’s devices change frequently.
async def get_device_keys(
    self,
    users: typing.Iterable[str],
    force_download=False,
    timeout: typing.Optional[int] = None,
) -> dict[str, "UserDeviceKeysResult"]:
    """Get the device keys for the given users.

    Arguments:

    ``users``:
      the user IDs to fetch device keys for
    ``force_download``:
      whether to ignore the cache and force downloading of all device keys from
      the server
    ``timeout``:
      a timeout in milliseconds for the homeserver to wait for responses from
      remote homeservers

    Returns a dict mapping the user IDs to a ``UserDeviceKeysResult``.
    """
    ret = {}
    users_needing_download: dict[str, list] = {}

    async with self.lock:
        {{load devices from cache}}
        {{check in-flight device key requests}}

    if users_needing_download != {}:
        {{download device keys}}
        {{process /keys/query result}}

    {{get in-flight device key results}}

    return ret
class UserDeviceKeysResult(typing.NamedTuple):
    """The return value of `DeviceTracker.get_device_keys` for a user"""

    device_keys: dict[str, dict]
    # FIXME: add cross-signing keys


UserDeviceKeysResult.device_keys.__doc__ = "mapping from device ID to device keys"
(Again, the two chunks referring to “in-flight device key requests” are for handling some edge cases described below.)
If the argument to force downloading is set to False, we will try to obtain
the device keys from the cache and keep track of the users that the cache
doesn’t have keys for. Our users_needing_download variable, which keeps track
of the users that we need to download keys for, will be formatted in the way
that POST /keys/query expects, to save us from converting it later on: it will
be an object mapping from the user ID to an array. The array is a list
indicating the devices that we want the keys for, or an empty array if we want
all of a user’s device keys. Since the latter is what we want, we set all
values to the empty array.
if not force_download:
    for user in users:
        storage_key = f"device_tracker.cache.{user}"
        if storage_key in self.client.storage:
            ret[user] = UserDeviceKeysResult(
                **self.client.storage[storage_key]
            )
        else:
            users_needing_download[user] = []
else:
    users_needing_download = {user: [] for user in users}
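As a rough illustration of where this mapping ends up, the following sketch builds a request body from it and prints the JSON. The user IDs and the timeout value are made up; the device_keys and timeout property names are the ones used by the request code in this section.

```python
import json
import typing

users = ["@bob:example.org", "@carol:example.org"]

# An empty list means "all of this user's devices".
users_needing_download: typing.Dict[str, list] = {user: [] for user in users}

req_body: typing.Dict[str, typing.Any] = {"device_keys": users_needing_download}

# The timeout is optional; this value (in milliseconds) is only an example.
timeout: typing.Optional[int] = 10000
if timeout is not None:
    req_body["timeout"] = timeout

print(json.dumps(req_body, indent=2))
```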
We then download the device keys from the server. The response from the server will be an object with several properties. The properties that we will be interested in here are:
device_keys is an object mapping from user ID to an object mapping device ID to device key.
failures is an object in which the keys are server names indicating that the remote server could not be contacted to provide the user’s keys. The values in the object do not matter. If a server is included in failures, the users on that server will not have any device information given. In this case, our function will return a result indicating that the users do not have any devices, but we will not cache this result so that it will retry later.
The other properties in the response are related to cross-signing, which will be discussed later on.
req_body: dict[str, typing.Any] = {"device_keys": users_needing_download}
if timeout is not None:
    req_body["timeout"] = timeout
try:
    resp = await self.client.authenticated(
        self.client.http_session.post,
        self.client.url("v3/keys/query"),
        json=req_body,
    )
    status, resp_body = await client.check_response(resp)
    schema.ensure_valid(
        resp_body,
        {
            "device_keys": schema.Optional(dict[str, dict[str, dict]]),
            "failures": schema.Optional(dict),
            # FIXME: cross-signing keys
        },
    )
except:
    {{clean up in-flight device key results}}
    raise
(Again, the “clean up in-flight device key results” chunk is for handling some edge cases described below.)
For each user that we queried, we check whether they have an entry in the
device_keys property of the response. If so, we check that the device keys
listed therein are valid (they satisfy the device key schema, and the user ID
and device ID match). We then cache the keys (when applicable) and add them to
the return value.
user_device_keys = resp_body.get("device_keys", {})
failures = resp_body.get("failures", {})
tracked_users = self.client.storage.get("device_tracker.tracked_users", {})
async with self.lock:
    for user in users_needing_download.keys():
        device_keys = user_device_keys.get(user, {})
        # only keep valid device keys
        device_keys = {
            device_id: device_key
            for device_id, device_key in device_keys.items()
            if (
                schema.is_valid(device_key, DEVICE_KEY_SCHEMA)
                and device_key["user_id"] == user
                and device_key["device_id"] == device_id
            )
        }
        user_device_info = UserDeviceKeysResult(device_keys)
        # FIXME: process cross-signing keys
        self._cache_result(user, tracked_users, failures, user_device_info)
        ret[user] = user_device_info
        {{set result for in-flight device key request}}
DEVICE_KEY_SCHEMA = {
    "algorithms": schema.Array(str),
    "device_id": str,
    "user_id": str,
    "keys": schema.Object(typing.Any),
    "signatures": schema.Optional(schema.Object(schema.Object(typing.Any))),
}
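To make the filtering step concrete without relying on our schema module, here is a sketch that uses a hand-written check in its place. The looks_like_device_key function is an invented, much looser stand-in for the schema validation, and the sample response is made up. The point is only to show which entries survive the user ID and device ID comparison.

```python
def looks_like_device_key(device_key: object) -> bool:
    """A hand-rolled stand-in for the schema check (an assumption of this
    sketch, not the validator used by the tracker)."""
    return (
        isinstance(device_key, dict)
        and isinstance(device_key.get("algorithms"), list)
        and isinstance(device_key.get("device_id"), str)
        and isinstance(device_key.get("user_id"), str)
        and isinstance(device_key.get("keys"), dict)
    )


def filter_device_keys(user: str, device_keys: dict) -> dict:
    """Keep only well-formed entries whose IDs match where they were found."""
    return {
        device_id: device_key
        for device_id, device_key in device_keys.items()
        if looks_like_device_key(device_key)
        and device_key["user_id"] == user
        and device_key["device_id"] == device_id
    }


response_for_bob = {
    # well-formed and consistent
    "HIJKLMN": {
        "algorithms": [],
        "device_id": "HIJKLMN",
        "user_id": "@bob:example.org",
        "keys": {},
    },
    # claims to belong to a different user
    "OPQRSTU": {
        "algorithms": [],
        "device_id": "OPQRSTU",
        "user_id": "@mallory:example.org",
        "keys": {},
    },
    # device ID in the body does not match the key it is listed under
    "VWXYZAB": {
        "algorithms": [],
        "device_id": "SOMETHINGELSE",
        "user_id": "@bob:example.org",
        "keys": {},
    },
    # missing required properties
    "BROKEN": {"device_id": "BROKEN"},
}

valid = filter_device_keys("@bob:example.org", response_for_bob)
print(sorted(valid))
```

Only the first entry is kept. The others are silently dropped rather than causing the whole response to be rejected.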
Todo
we should also make sure the ed25519 key doesn’t change
def _cache_result(self, user, tracked_users, failures, user_device_info):
    if self._should_cache_result(user, tracked_users, failures):
        self.client.storage[
            f"device_tracker.cache.{user}"
        ] = user_device_info._asdict()
(Again, the “set result for in-flight device key request” chunk is for handling some edge cases described below.)
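The cache stores the result as a plain dict and rebuilds the named tuple when loading. The following sketch shows that round trip on its own, using a local copy of the UserDeviceKeysResult definition and a made-up device entry.

```python
import typing


class UserDeviceKeysResult(typing.NamedTuple):
    """Local copy of the result type, for illustration only."""

    device_keys: typing.Dict[str, dict]


result = UserDeviceKeysResult(
    {
        "HIJKLMN": {
            "algorithms": [],
            "device_id": "HIJKLMN",
            "keys": {"curve25519:HIJKLMN": "some+key"},
            "user_id": "@bob:example.org",
        }
    }
)

# What gets written to storage: a mapping from field name to value.
stored = result._asdict()

# What gets read back: the field names become keyword arguments.
restored = UserDeviceKeysResult(**stored)

print(restored == result)
```

Because the stored form is keyed by field name, fields that are added to the named tuple later would need default values for older cache entries to keep loading.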
We create a function that tells us whether the results should be cached. As mentioned above, we don’t cache results for users who aren’t being tracked, or for users whose homeserver could not be contacted.
def _should_cache_result(
    self, user: str, tracked_users: dict, failures: dict
) -> bool:
    {{check do not cache flag}}
    if user not in tracked_users:
        return False
    # find the user's server name
    split_user_id = user.split(":", 1)
    if len(split_user_id) != 2:
        # invalid user ID, so don't cache
        return False
    if split_user_id[1] in failures:
        return False
    return True
(The “check do not cache flag” chunk is for handling edge cases described below.)
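The decision can be exercised as a standalone function. The sketch below mirrors the checks above, without the “check do not cache flag” chunk, and runs them against a few made-up user IDs. Note that splitting only at the first colon keeps a server name that includes a port intact.

```python
def should_cache(user: str, tracked_users: dict, failures: dict) -> bool:
    """Standalone version of the caching decision, for illustration."""
    if user not in tracked_users:
        return False
    # Split at the first colon only, so "example.com:8448" stays in one piece.
    parts = user.split(":", 1)
    if len(parts) != 2:
        return False
    return parts[1] not in failures


tracked = {
    "@bob:example.org": True,
    "@dave:example.com:8448": True,
    "not-a-user-id": True,
}
failures = {"example.com:8448": {}}

decisions = {
    user: should_cache(user, tracked, failures)
    for user in [
        "@bob:example.org",  # tracked, server reachable
        "@carol:example.org",  # not tracked
        "@dave:example.com:8448",  # tracked, but server listed in failures
        "not-a-user-id",  # tracked, but has no server name
    ]
}
print(decisions)
```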
Tests
tests/test_device_tracking.py:
# {{copyright}}
import asyncio
import aioresponses
import json
import pytest
from matrixlib import client
from matrixlib import devices
{{test device tracking}}
To test the device tracker, we will pre-populate our device cache, and first ensure that when we try to get the device keys, we get the cached values.
@pytest.mark.asyncio
async def test_basic_device_tracking(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
            "device_tracker.tracked_users": {
                "@bob:example.org": True,
                "@carol:example.org": True,
            },
            "device_tracker.cache.@bob:example.org": {
                "device_keys": {
                    "HIJKLMN": {
                        "algorithms": [],
                        "device_id": "HIJKLMN",
                        "keys": {
                            "curve25519:HIJKLMN": "some+key",
                        },
                        "user_id": "@bob:example.org",
                    },
                },
            },
            "device_tracker.cache.@carol:example.org": {
                "device_keys": {
                    "OPQRSTU": {
                        "algorithms": [],
                        "device_id": "OPQRSTU",
                        "keys": {
                            "curve25519:OPQRSTU": "some+other+key",
                        },
                        "user_id": "@carol:example.org",
                    },
                },
            },
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as c:
        {{basic device tracking test}}
tracker = devices.DeviceTracker(c)
assert await tracker.get_device_keys(
    ["@bob:example.org", "@carol:example.org"]
) == {
    "@bob:example.org": devices.UserDeviceKeysResult(
        {
            "HIJKLMN": {
                "algorithms": [],
                "device_id": "HIJKLMN",
                "keys": {
                    "curve25519:HIJKLMN": "some+key",
                },
                "user_id": "@bob:example.org",
            },
        }
    ),
    "@carol:example.org": devices.UserDeviceKeysResult(
        {
            "OPQRSTU": {
                "algorithms": [],
                "device_id": "OPQRSTU",
                "keys": {
                    "curve25519:OPQRSTU": "some+other+key",
                },
                "user_id": "@carol:example.org",
            },
        },
    ),
}
We then create a sync response that will indicate that Bob’s devices have been updated.
mock_aioresponse.get(
    "https://matrix-client.example.org/_matrix/client/v3/sync?timeout=30000",
    status=200,
    body=json.dumps(
        {
            "device_lists": {
                "changed": ["@bob:example.org"],
            },
            "next_batch": "token1",
        }
    ),
    headers={
        "content-type": "application/json",
    },
)
mock_aioresponse.get(
    "https://matrix-client.example.org/_matrix/client/v3/sync?since=token1&timeout=30000",
    status=200,
    body='{"next_batch":"token1"}',
    headers={
        "content-type": "application/json",
    },
    repeat=True,
)


def subscriber(msg) -> None:
    c.stop_sync()


c.publisher.subscribe((client.DeviceChanges, client.SyncFailed), subscriber)
c.start_sync()
try:
    await c.sync_task
except asyncio.CancelledError:
    pass
We now ensure that when we try to get the device keys, it refreshes Bob’s keys and returns the new keys. We use a callback in our HTTP request mock handler to ensure that the request body only requests Bob’s keys.
def callback(url, **kwargs):
    assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}
    return aioresponses.CallbackResult(
        status=200,
        body=json.dumps(
            {
                "device_keys": {
                    "@bob:example.org": {
                        "VWXYZAB": {
                            "algorithms": [],
                            "device_id": "VWXYZAB",
                            "keys": {
                                "curve25519:HIJKLMN": "some+new+key",
                            },
                            "user_id": "@bob:example.org",
                        },
                    },
                },
            }
        ),
        headers={
            "Content-Type": "application/json",
        },
    )


mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/query",
    callback=callback,
)
assert await tracker.get_device_keys(
    ["@bob:example.org", "@carol:example.org"]
) == {
    "@bob:example.org": devices.UserDeviceKeysResult(
        {
            "VWXYZAB": {
                "algorithms": [],
                "device_id": "VWXYZAB",
                "keys": {
                    "curve25519:HIJKLMN": "some+new+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    ),
    "@carol:example.org": devices.UserDeviceKeysResult(
        {
            "OPQRSTU": {
                "algorithms": [],
                "device_id": "OPQRSTU",
                "keys": {
                    "curve25519:OPQRSTU": "some+other+key",
                },
                "user_id": "@carol:example.org",
            },
        },
    ),
}
Edge cases
There are a couple of edge cases that we need to deal with. First of all, if
we are querying the server for the device keys for a user, and concurrently,
another call to get_device_keys needs to fetch the device keys for the same
user, we don’t need to query the server for that user again. Instead, we can
just use the result that we get from the first query. (While we could just
make a second request and not worry about efficiency, this raises the question
of how this will affect our caching — if the two results differ, which result
should be cached? This would also complicate handling for the second edge
case. And, as it turns out, the record-keeping that we do to handle this edge
case will also help with handling the second edge case. So while on the
surface it may seem that it is simpler to just make a second request rather
than synchronizing between calls to our function, it actually turns out to be
simpler to do the synchronization.)
The second edge case is that if we are querying the server for the device keys for a user, and concurrently, we get a sync response saying that the user’s devices have changed, we do not know if the result of our server query represents the user’s old devices or their new devices. In this case, we could retry the request. But what do we do if we get another sync response saying that the user’s devices have changed while the second request is in-flight? How many times will we keep retrying? If we don’t limit the number of retries, this could loop infinitely. Instead of retrying, for the sake of simplicity, in our implementation we will simply return the first result, but we will not cache the value so that the next time we request the user’s devices, we will re-query. Even though we may not be obtaining the most up-to-date value, we are still providing a correct response for the time that the request was made. Other implementations could use some sort of retry mechanism, but should take care to ensure that it will return after a reasonable amount of time, even if the server continuously indicates that devices have changed.
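For comparison, here is one way that a bounded retry could look. This is not what our implementation does; it is only a sketch of the alternative mentioned above. The function name, its parameters and the default limit of three attempts are all invented for the illustration. The caller supplies a query coroutine and a function that reports whether a change notification arrived while the query was running.

```python
import asyncio
import typing


async def query_with_bounded_retries(
    query: typing.Callable[[], typing.Awaitable[dict]],
    changed_during_query: typing.Callable[[], bool],
    max_attempts: int = 3,
) -> typing.Tuple[dict, bool]:
    """Return (result, is_fresh), giving up after max_attempts queries."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    result: dict = {}
    for _ in range(max_attempts):
        result = await query()
        if not changed_during_query():
            # Nothing changed while we were querying, so the result is current.
            return result, True
    # Still changing: return the last result, flagged as possibly outdated.
    return result, False


async def demo(change_pattern: typing.List[bool]) -> typing.Tuple[dict, bool]:
    attempts = 0
    remaining = list(change_pattern)

    async def fake_query() -> dict:
        nonlocal attempts
        attempts += 1
        return {"attempt": attempts}

    def fake_changed() -> bool:
        # Report the pre-programmed outcome for this attempt.
        return remaining.pop(0)

    return await query_with_bounded_retries(fake_query, fake_changed)


settled = asyncio.run(demo([True, False, False]))
gave_up = asyncio.run(demo([True, True, True]))
print(settled, gave_up)
```

The second element of the return value tells the caller whether the result is safe to cache, which plays the same role as declining to cache in our simpler approach.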
To handle the first edge case, when we make a server request to query a user’s
device keys, we will record in our DeviceTracker object that we are querying
that user’s keys. If we get another call to get_device_keys for that user’s
keys, we will note that there is already a request in-flight, and we can wait
for the first request to complete, and get the value from there. We will make
use of Python’s asyncio.Future class for this, which
represents a variable that will have a value in the future: when we query the
server for a user’s device keys, we will store a Future, and when we receive
the result, we will resolve the Future to the user’s devices. We will store
the Futures in the in_flight member variable, which will be a dict
mapping user IDs to Futures.
self.in_flight: dict[str, asyncio.Future] = {}
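For readers who have not used asyncio.Future before, here is a minimal sketch that is independent of the DeviceTracker. Two tasks wait on the same Future, and both receive the value once it is set. The task names and the string used as the result are made up.

```python
import asyncio
import typing


async def main() -> typing.List[str]:
    loop = asyncio.get_running_loop()
    future = loop.create_future()

    async def waiter(name: str) -> str:
        # Suspends until the future has a result.
        value = await future
        return f"{name} got {value}"

    tasks = [
        asyncio.create_task(waiter("first")),
        asyncio.create_task(waiter("second")),
    ]
    # Let both waiters start and block on the future.
    await asyncio.sleep(0)
    future.set_result("bob's devices")
    return await asyncio.gather(*tasks)


messages = asyncio.run(main())
print(messages)
```

A Future can be awaited by any number of tasks, and can still be awaited after it has been resolved, which is what lets several callers share the result of a single request.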
We now write the code chunks in our get_device_keys function that we
mentioned above would be used for dealing with the edge cases.
The first chunk occurs before we make our request to the server. It will look
at the users that we were going to request from the server, and see if we have
in-flight requests for those users. If so, it will record the Futures for
those users, and drop them from our request. If not, it will mark those users
as having requests in-flight, since we will be making the request.
loop = asyncio.get_running_loop()
device_futures = {}
for user in users_needing_download.keys():
    # If we already have a request in-flight for the user, we can use that
    # result instead of re-requesting. Otherwise, record that we will be
    # making a request for that user.
    if user in self.in_flight:
        device_futures[user] = self.in_flight[user]
    else:
        self.in_flight[user] = loop.create_future()
for user in device_futures.keys():
    # drop users from our request if we're using the result from a future
    del users_needing_download[user]
The second chunk occurs after we have made our request to the server. If the
request raises an exception, we will set that exception on our Futures, so
that the exception gets passed to anything that was waiting on our result.
e = typing.cast(BaseException, sys.exc_info()[1])
for user in users_needing_download.keys():
    self.in_flight[user].set_exception(e)
    del self.in_flight[user]
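The effect of set_exception on a waiting task can be seen in a small standalone sketch. The exception type and message are made up; the point is that the exception set on the Future is raised at the await in the waiting task.

```python
import asyncio


async def main() -> str:
    loop = asyncio.get_running_loop()
    future = loop.create_future()

    async def waiter() -> str:
        try:
            await future
        except ConnectionError as e:
            return f"waiter saw: {e}"
        return "waiter saw a result"

    task = asyncio.create_task(waiter())
    # Let the waiter start and block on the future.
    await asyncio.sleep(0)

    # Simulate a failed request and pass the failure on to the waiter.
    try:
        raise ConnectionError("homeserver unreachable")
    except ConnectionError as e:
        future.set_exception(e)

    return await task


outcome = asyncio.run(main())
print(outcome)
```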
The third chunk sets the result of our Futures after a successful request.
So whether the request fails or succeeds, the Future will resolve.
self.in_flight[user].set_result(user_device_info)
del self.in_flight[user]
And the last chunk gets the result for any Futures that we are waiting on.
Note that we get the result for the Futures after we make our own request to
the server. That is because we don’t want to wait for the Futures to resolve
before we make our request; the requests can be made concurrently.
for user, future in device_futures.items():
    ret[user] = await future
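Putting the four chunks together, the pattern can be reduced to the following self-contained sketch. The DedupingFetcher class and its fake server query are invented for this illustration and handle a single user per call; our DeviceTracker does the same thing across a batch of users and with caching on top.

```python
import asyncio
import typing


class DedupingFetcher:
    """Shares one in-flight request between concurrent callers."""

    def __init__(self) -> None:
        self.in_flight: typing.Dict[str, asyncio.Future] = {}
        self.server_calls = 0

    async def _query_server(self, user: str) -> str:
        # Stand-in for the real HTTP request.
        self.server_calls += 1
        await asyncio.sleep(0.01)
        return f"devices of {user}"

    async def get(self, user: str) -> str:
        if user in self.in_flight:
            # Someone else is already asking; wait for their answer.
            return await self.in_flight[user]
        # No await happens between the check above and this assignment, so
        # no other task can slip in between them.
        future = asyncio.get_running_loop().create_future()
        self.in_flight[user] = future
        try:
            result = await self._query_server(user)
        except BaseException as e:
            future.set_exception(e)
            raise
        else:
            future.set_result(result)
            return result
        finally:
            del self.in_flight[user]


async def main() -> typing.Tuple[typing.List[str], int]:
    fetcher = DedupingFetcher()
    results = await asyncio.gather(
        fetcher.get("@bob:example.org"),
        fetcher.get("@bob:example.org"),
    )
    return results, fetcher.server_calls


results, server_calls = asyncio.run(main())
print(results, server_calls)
```

Both callers receive the same value, but the fake server is only queried once.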
Tests
@pytest.mark.asyncio
async def test_concurrent_device_requests(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
            "device_tracker.tracked_users": {
                "@bob:example.org": True,
            },
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as c:
        {{concurrent device requests test}}
To test that we can make concurrent requests, we will first make a request for a user’s device keys. We will process the request for the user’s keys using a custom handler, and make another request for that user’s device keys within that handler, to simulate a request being made concurrently. We will ensure that the device tracker does not make another request for the user’s device keys.
tracker = devices.DeviceTracker(c)
second_request_task = None


async def callback(url, **kwargs):
    assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}
    nonlocal second_request_task
    second_request_task = asyncio.create_task(
        tracker.get_device_keys(["@bob:example.org"])
    )
    # give the second request some time to execute to make sure that it blocks
    # on the future
    await asyncio.sleep(0.2)
    return aioresponses.CallbackResult(
        status=200,
        body=json.dumps(
            {
                "device_keys": {
                    "@bob:example.org": {
                        "HIJKLMN": {
                            "algorithms": [],
                            "device_id": "HIJKLMN",
                            "keys": {
                                "curve25519:HIJKLMN": "some+key",
                            },
                            "user_id": "@bob:example.org",
                        },
                    },
                },
            }
        ),
        headers={
            "Content-Type": "application/json",
        },
    )


mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/query",
    callback=callback,
)
assert await tracker.get_device_keys(["@bob:example.org"]) == {
    "@bob:example.org": devices.UserDeviceKeysResult(
        {
            "HIJKLMN": {
                "algorithms": [],
                "device_id": "HIJKLMN",
                "keys": {
                    "curve25519:HIJKLMN": "some+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    ),
}
assert await second_request_task == {
    "@bob:example.org": devices.UserDeviceKeysResult(
        {
            "HIJKLMN": {
                "algorithms": [],
                "device_id": "HIJKLMN",
                "keys": {
                    "curve25519:HIJKLMN": "some+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    ),
}
Now we will handle the second edge case, where a sync comes in while we’re in the process of querying the server for a user’s devices. As explained above, we will return the result that we obtain from the server, but we will not cache the result, since it could be outdated.
When we receive the DeviceChanges message from the sync loop, we will check
whether the users that are in changed or left have an in-flight request,
and if so, we will mark those users indicating that their results should not be
cached.
self.do_not_cache: typing.Set[str] = set()
if user in self.in_flight:
    self.do_not_cache.add(user)
Then, in our _should_cache_result function, we check whether the user is
marked for not caching, and if so, we return that they should not be cached
after clearing the flag. Since we know that there is no other request for that
user, it is safe to clear the flag so that we will cache their result the next
time we query their devices.
if user in self.do_not_cache:
    self.do_not_cache.remove(user)
    return False
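The interaction between the in-flight record and the do-not-cache flag can be modelled on its own. In the sketch below, the CacheGate class and its method names are invented for the illustration. It leaves out the actual requests and storage, and only tracks whether the result of a finished request may be cached.

```python
import typing


class CacheGate:
    """Toy model of the in-flight and do-not-cache bookkeeping."""

    def __init__(self) -> None:
        self.in_flight: typing.Set[str] = set()
        self.do_not_cache: typing.Set[str] = set()

    def request_started(self, user: str) -> None:
        self.in_flight.add(user)

    def devices_changed(self, user: str) -> None:
        # Only matters if a request for this user is currently running.
        if user in self.in_flight:
            self.do_not_cache.add(user)

    def request_finished(self, user: str) -> bool:
        """Return whether the result of the finished request may be cached."""
        self.in_flight.discard(user)
        if user in self.do_not_cache:
            # Clear the flag so that the next request can be cached again.
            self.do_not_cache.remove(user)
            return False
        return True


gate = CacheGate()
bob = "@bob:example.org"

# A change notification arrives while the request is in flight.
gate.request_started(bob)
gate.devices_changed(bob)
first = gate.request_finished(bob)

# The next request completes undisturbed.
gate.request_started(bob)
second = gate.request_finished(bob)

# A change notification with no request in flight leaves no flag behind.
gate.devices_changed(bob)
gate.request_started(bob)
third = gate.request_finished(bob)

print(first, second, third)
```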
Tests
@pytest.mark.asyncio
async def test_update_during_device_request(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
            "device_tracker.tracked_users": {
                "@bob:example.org": True,
            },
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as c:
        {{update during device request test}}
To test that a device update arriving during a request prevents the result
from being cached, we will first make a request for a user’s device keys.
Again, we will process the request for the user’s keys using a custom handler,
but this time in our handler, we will publish a DeviceChanges message to the
tracker, simulating what would happen if a sync response came in. We then
ensure that when we make another request, it will re-query the server.
tracker = devices.DeviceTracker(c)


async def callback1(url, **kwargs):
    assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}
    await c.publisher.publish(client.DeviceChanges(["@bob:example.org"], []))
    return aioresponses.CallbackResult(
        status=200,
        body=json.dumps(
            {
                "device_keys": {
                    "@bob:example.org": {
                        "HIJKLMN": {
                            "algorithms": [],
                            "device_id": "HIJKLMN",
                            "keys": {
                                "curve25519:HIJKLMN": "some+key",
                            },
                            "user_id": "@bob:example.org",
                        },
                    },
                },
            }
        ),
        headers={
            "Content-Type": "application/json",
        },
    )


mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/query",
    callback=callback1,
)


def callback2(url, **kwargs):
    assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}
    return aioresponses.CallbackResult(
        status=200,
        body=json.dumps(
            {
                "device_keys": {
                    "@bob:example.org": {
                        "VWXYZAB": {
                            "algorithms": [],
                            "device_id": "VWXYZAB",
                            "keys": {
                                "curve25519:HIJKLMN": "some+new+key",
                            },
                            "user_id": "@bob:example.org",
                        },
                    },
                },
            }
        ),
        headers={
            "Content-Type": "application/json",
        },
    )


mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/query",
    callback=callback2,
)
assert await tracker.get_device_keys(["@bob:example.org"]) == {
    "@bob:example.org": devices.UserDeviceKeysResult(
        {
            "HIJKLMN": {
                "algorithms": [],
                "device_id": "HIJKLMN",
                "keys": {
                    "curve25519:HIJKLMN": "some+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    ),
}
assert await tracker.get_device_keys(["@bob:example.org"]) == {
    "@bob:example.org": devices.UserDeviceKeysResult(
        {
            "VWXYZAB": {
                "algorithms": [],
                "device_id": "VWXYZAB",
                "keys": {
                    "curve25519:HIJKLMN": "some+new+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    ),
}