Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
synapse
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Timo Ley
synapse
Commits
77d0a450
Unverified
Commit
77d0a450
authored
5 years ago
by
Patrick Cloke
Committed by
GitHub
5 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Add type annotations and comments to auth handler (#7063)
parent
bd5e555b
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
changelog.d/7063.misc
+1
-0
1 addition, 0 deletions
changelog.d/7063.misc
synapse/handlers/auth.py
+104
-89
104 additions, 89 deletions
synapse/handlers/auth.py
tox.ini
+1
-0
1 addition, 0 deletions
tox.ini
with
106 additions
and
89 deletions
changelog.d/7063.misc
0 → 100644
+
1
−
0
View file @
77d0a450
Add type annotations and comments to the auth handler.
This diff is collapsed.
Click to expand it.
synapse/handlers/auth.py
+
104
−
89
View file @
77d0a450
...
@@ -18,10 +18,10 @@ import logging
...
@@ -18,10 +18,10 @@ import logging
import
time
import
time
import
unicodedata
import
unicodedata
import
urllib.parse
import
urllib.parse
from
typing
import
Any
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
import
attr
import
attr
import
bcrypt
import
bcrypt
# type: ignore[import]
import
pymacaroons
import
pymacaroons
from
twisted.internet
import
defer
from
twisted.internet
import
defer
...
@@ -45,7 +45,7 @@ from synapse.http.site import SynapseRequest
...
@@ -45,7 +45,7 @@ from synapse.http.site import SynapseRequest
from
synapse.logging.context
import
defer_to_thread
from
synapse.logging.context
import
defer_to_thread
from
synapse.module_api
import
ModuleApi
from
synapse.module_api
import
ModuleApi
from
synapse.push.mailer
import
load_jinja2_templates
from
synapse.push.mailer
import
load_jinja2_templates
from
synapse.types
import
UserID
from
synapse.types
import
Requester
,
UserID
from
synapse.util.caches.expiringcache
import
ExpiringCache
from
synapse.util.caches.expiringcache
import
ExpiringCache
from
._base
import
BaseHandler
from
._base
import
BaseHandler
...
@@ -63,11 +63,11 @@ class AuthHandler(BaseHandler):
...
@@ -63,11 +63,11 @@ class AuthHandler(BaseHandler):
"""
"""
super
(
AuthHandler
,
self
).
__init__
(
hs
)
super
(
AuthHandler
,
self
).
__init__
(
hs
)
self
.
checkers
=
{}
# type:
d
ict[str, UserInteractiveAuthChecker]
self
.
checkers
=
{}
# type:
D
ict[str, UserInteractiveAuthChecker]
for
auth_checker_class
in
INTERACTIVE_AUTH_CHECKERS
:
for
auth_checker_class
in
INTERACTIVE_AUTH_CHECKERS
:
inst
=
auth_checker_class
(
hs
)
inst
=
auth_checker_class
(
hs
)
if
inst
.
is_enabled
():
if
inst
.
is_enabled
():
self
.
checkers
[
inst
.
AUTH_TYPE
]
=
inst
self
.
checkers
[
inst
.
AUTH_TYPE
]
=
inst
# type: ignore
self
.
bcrypt_rounds
=
hs
.
config
.
bcrypt_rounds
self
.
bcrypt_rounds
=
hs
.
config
.
bcrypt_rounds
...
@@ -124,7 +124,9 @@ class AuthHandler(BaseHandler):
...
@@ -124,7 +124,9 @@ class AuthHandler(BaseHandler):
self
.
_whitelisted_sso_clients
=
tuple
(
hs
.
config
.
sso_client_whitelist
)
self
.
_whitelisted_sso_clients
=
tuple
(
hs
.
config
.
sso_client_whitelist
)
@defer.inlineCallbacks
@defer.inlineCallbacks
def
validate_user_via_ui_auth
(
self
,
requester
,
request_body
,
clientip
):
def
validate_user_via_ui_auth
(
self
,
requester
:
Requester
,
request_body
:
Dict
[
str
,
Any
],
clientip
:
str
):
"""
"""
Checks that the user is who they claim to be, via a UI auth.
Checks that the user is who they claim to be, via a UI auth.
...
@@ -133,11 +135,11 @@ class AuthHandler(BaseHandler):
...
@@ -133,11 +135,11 @@ class AuthHandler(BaseHandler):
that it isn
'
t stolen by re-authenticating them.
that it isn
'
t stolen by re-authenticating them.
Args:
Args:
requester
(Requester)
: The user, as given by the access token
requester: The user, as given by the access token
request_body
(dict)
: The body of the request sent by the client
request_body: The body of the request sent by the client
clientip
(str)
: The IP address of the client.
clientip: The IP address of the client.
Returns:
Returns:
defer.Deferred[dict]: the parameters for this request (which may
defer.Deferred[dict]: the parameters for this request (which may
...
@@ -208,7 +210,9 @@ class AuthHandler(BaseHandler):
...
@@ -208,7 +210,9 @@ class AuthHandler(BaseHandler):
return
self
.
checkers
.
keys
()
return
self
.
checkers
.
keys
()
@defer.inlineCallbacks
@defer.inlineCallbacks
def
check_auth
(
self
,
flows
,
clientdict
,
clientip
):
def
check_auth
(
self
,
flows
:
List
[
List
[
str
]],
clientdict
:
Dict
[
str
,
Any
],
clientip
:
str
):
"""
"""
Takes a dictionary sent by the client in the login / registration
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
protocol and handles the User-Interactive Auth flow.
...
@@ -223,14 +227,14 @@ class AuthHandler(BaseHandler):
...
@@ -223,14 +227,14 @@ class AuthHandler(BaseHandler):
decorator.
decorator.
Args:
Args:
flows
(list)
: A list of login flows. Each flow is an ordered list of
flows: A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
clientdict: The dictionary from the client root level, not the
'
auth
'
key: this method prompts for auth if none is sent.
'
auth
'
key: this method prompts for auth if none is sent.
clientip
(str)
: The IP address of the client.
clientip: The IP address of the client.
Returns:
Returns:
defer.Deferred[dict, dict, str]: a deferred tuple of
defer.Deferred[dict, dict, str]: a deferred tuple of
...
@@ -250,7 +254,7 @@ class AuthHandler(BaseHandler):
...
@@ -250,7 +254,7 @@ class AuthHandler(BaseHandler):
"""
"""
authdict
=
None
authdict
=
None
sid
=
None
sid
=
None
# type: Optional[str]
if
clientdict
and
"
auth
"
in
clientdict
:
if
clientdict
and
"
auth
"
in
clientdict
:
authdict
=
clientdict
[
"
auth
"
]
authdict
=
clientdict
[
"
auth
"
]
del
clientdict
[
"
auth
"
]
del
clientdict
[
"
auth
"
]
...
@@ -283,9 +287,9 @@ class AuthHandler(BaseHandler):
...
@@ -283,9 +287,9 @@ class AuthHandler(BaseHandler):
creds
=
session
[
"
creds
"
]
creds
=
session
[
"
creds
"
]
# check auth type currently being presented
# check auth type currently being presented
errordict
=
{}
errordict
=
{}
# type: Dict[str, Any]
if
"
type
"
in
authdict
:
if
"
type
"
in
authdict
:
login_type
=
authdict
[
"
type
"
]
login_type
=
authdict
[
"
type
"
]
# type: str
try
:
try
:
result
=
yield
self
.
_check_auth_dict
(
authdict
,
clientip
)
result
=
yield
self
.
_check_auth_dict
(
authdict
,
clientip
)
if
result
:
if
result
:
...
@@ -326,7 +330,7 @@ class AuthHandler(BaseHandler):
...
@@ -326,7 +330,7 @@ class AuthHandler(BaseHandler):
raise
InteractiveAuthIncompleteError
(
ret
)
raise
InteractiveAuthIncompleteError
(
ret
)
@defer.inlineCallbacks
@defer.inlineCallbacks
def
add_oob_auth
(
self
,
stagetype
,
authdict
,
clientip
):
def
add_oob_auth
(
self
,
stagetype
:
str
,
authdict
:
Dict
[
str
,
Any
]
,
clientip
:
str
):
"""
"""
Adds the result of out-of-band authentication into an existing auth
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
session. Currently used for adding the result of fallback auth.
...
@@ -348,7 +352,7 @@ class AuthHandler(BaseHandler):
...
@@ -348,7 +352,7 @@ class AuthHandler(BaseHandler):
return
True
return
True
return
False
return
False
def
get_session_id
(
self
,
clientdict
)
:
def
get_session_id
(
self
,
clientdict
:
Dict
[
str
,
Any
])
->
Optional
[
str
]
:
"""
"""
Gets the session ID for a client given the client dictionary
Gets the session ID for a client given the client dictionary
...
@@ -356,7 +360,7 @@ class AuthHandler(BaseHandler):
...
@@ -356,7 +360,7 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary sent by the client in the request
clientdict: The dictionary sent by the client in the request
Returns:
Returns:
str|None:
The string session ID the client sent. If the client did
The string session ID the client sent. If the client did
not send a session ID, returns None.
not send a session ID, returns None.
"""
"""
sid
=
None
sid
=
None
...
@@ -366,40 +370,42 @@ class AuthHandler(BaseHandler):
...
@@ -366,40 +370,42 @@ class AuthHandler(BaseHandler):
sid
=
authdict
[
"
session
"
]
sid
=
authdict
[
"
session
"
]
return
sid
return
sid
def
set_session_data
(
self
,
session_id
,
key
,
value
)
:
def
set_session_data
(
self
,
session_id
:
str
,
key
:
str
,
value
:
Any
)
->
None
:
"""
"""
Store a key-value pair into the sessions data associated with this
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
request. This data is stored server-side and cannot be modified by
the client.
the client.
Args:
Args:
session_id
(string)
: The ID of this session as returned from check_auth
session_id: The ID of this session as returned from check_auth
key
(string)
: The key to store the data under
key: The key to store the data under
value
(any)
: The data to store
value: The data to store
"""
"""
sess
=
self
.
_get_session_info
(
session_id
)
sess
=
self
.
_get_session_info
(
session_id
)
sess
.
setdefault
(
"
serverdict
"
,
{})[
key
]
=
value
sess
.
setdefault
(
"
serverdict
"
,
{})[
key
]
=
value
self
.
_save_session
(
sess
)
self
.
_save_session
(
sess
)
def
get_session_data
(
self
,
session_id
,
key
,
default
=
None
):
def
get_session_data
(
self
,
session_id
:
str
,
key
:
str
,
default
:
Optional
[
Any
]
=
None
)
->
Any
:
"""
"""
Retrieve data stored with set_session_data
Retrieve data stored with set_session_data
Args:
Args:
session_id
(string)
: The ID of this session as returned from check_auth
session_id: The ID of this session as returned from check_auth
key
(string)
: The key to store the data under
key: The key to store the data under
default
(any)
: Value to return if the key has not been set
default: Value to return if the key has not been set
"""
"""
sess
=
self
.
_get_session_info
(
session_id
)
sess
=
self
.
_get_session_info
(
session_id
)
return
sess
.
setdefault
(
"
serverdict
"
,
{}).
get
(
key
,
default
)
return
sess
.
setdefault
(
"
serverdict
"
,
{}).
get
(
key
,
default
)
@defer.inlineCallbacks
@defer.inlineCallbacks
def
_check_auth_dict
(
self
,
authdict
,
clientip
):
def
_check_auth_dict
(
self
,
authdict
:
Dict
[
str
,
Any
]
,
clientip
:
str
):
"""
Attempt to validate the auth dict provided by a client
"""
Attempt to validate the auth dict provided by a client
Args:
Args:
authdict
(object)
: auth dict provided by the client
authdict: auth dict provided by the client
clientip
(str)
: IP address of the client
clientip: IP address of the client
Returns:
Returns:
Deferred: result of the stage verification.
Deferred: result of the stage verification.
...
@@ -425,10 +431,10 @@ class AuthHandler(BaseHandler):
...
@@ -425,10 +431,10 @@ class AuthHandler(BaseHandler):
(
canonical_id
,
callback
)
=
yield
self
.
validate_login
(
user_id
,
authdict
)
(
canonical_id
,
callback
)
=
yield
self
.
validate_login
(
user_id
,
authdict
)
return
canonical_id
return
canonical_id
def
_get_params_recaptcha
(
self
):
def
_get_params_recaptcha
(
self
)
->
dict
:
return
{
"
public_key
"
:
self
.
hs
.
config
.
recaptcha_public_key
}
return
{
"
public_key
"
:
self
.
hs
.
config
.
recaptcha_public_key
}
def
_get_params_terms
(
self
):
def
_get_params_terms
(
self
)
->
dict
:
return
{
return
{
"
policies
"
:
{
"
policies
"
:
{
"
privacy_policy
"
:
{
"
privacy_policy
"
:
{
...
@@ -445,7 +451,9 @@ class AuthHandler(BaseHandler):
...
@@ -445,7 +451,9 @@ class AuthHandler(BaseHandler):
}
}
}
}
def
_auth_dict_for_flows
(
self
,
flows
,
session
):
def
_auth_dict_for_flows
(
self
,
flows
:
List
[
List
[
str
]],
session
:
Dict
[
str
,
Any
]
)
->
Dict
[
str
,
Any
]:
public_flows
=
[]
public_flows
=
[]
for
f
in
flows
:
for
f
in
flows
:
public_flows
.
append
(
f
)
public_flows
.
append
(
f
)
...
@@ -455,7 +463,7 @@ class AuthHandler(BaseHandler):
...
@@ -455,7 +463,7 @@ class AuthHandler(BaseHandler):
LoginType
.
TERMS
:
self
.
_get_params_terms
,
LoginType
.
TERMS
:
self
.
_get_params_terms
,
}
}
params
=
{}
params
=
{}
# type: Dict[str, Any]
for
f
in
public_flows
:
for
f
in
public_flows
:
for
stage
in
f
:
for
stage
in
f
:
...
@@ -468,7 +476,13 @@ class AuthHandler(BaseHandler):
...
@@ -468,7 +476,13 @@ class AuthHandler(BaseHandler):
"
params
"
:
params
,
"
params
"
:
params
,
}
}
def
_get_session_info
(
self
,
session_id
):
def
_get_session_info
(
self
,
session_id
:
Optional
[
str
])
->
dict
:
"""
Gets or creates a session given a session ID.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
"""
if
session_id
not
in
self
.
sessions
:
if
session_id
not
in
self
.
sessions
:
session_id
=
None
session_id
=
None
...
@@ -481,7 +495,9 @@ class AuthHandler(BaseHandler):
...
@@ -481,7 +495,9 @@ class AuthHandler(BaseHandler):
return
self
.
sessions
[
session_id
]
return
self
.
sessions
[
session_id
]
@defer.inlineCallbacks
@defer.inlineCallbacks
def
get_access_token_for_user_id
(
self
,
user_id
,
device_id
,
valid_until_ms
):
def
get_access_token_for_user_id
(
self
,
user_id
:
str
,
device_id
:
Optional
[
str
],
valid_until_ms
:
Optional
[
int
]
):
"""
"""
Creates a new access token for the user with the given user ID.
Creates a new access token for the user with the given user ID.
...
@@ -491,11 +507,11 @@ class AuthHandler(BaseHandler):
...
@@ -491,11 +507,11 @@ class AuthHandler(BaseHandler):
The device will be recorded in the table if it is not there already.
The device will be recorded in the table if it is not there already.
Args:
Args:
user_id
(str)
: canonical User ID
user_id: canonical User ID
device_id
(str|None)
: the device ID to associate with the tokens.
device_id: the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
we should always have a device ID)
valid_until_ms
(int|None)
: when the token is valid until. None for
valid_until_ms: when the token is valid until. None for
no expiry.
no expiry.
Returns:
Returns:
The access token for the user
'
s session.
The access token for the user
'
s session.
...
@@ -530,13 +546,13 @@ class AuthHandler(BaseHandler):
...
@@ -530,13 +546,13 @@ class AuthHandler(BaseHandler):
return
access_token
return
access_token
@defer.inlineCallbacks
@defer.inlineCallbacks
def
check_user_exists
(
self
,
user_id
):
def
check_user_exists
(
self
,
user_id
:
str
):
"""
"""
Checks to see if a user with the given id exists. Will check case
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
insensitively, but return None if there are multiple inexact matches.
Args:
Args:
(unicode|bytes)
user_id: complete @user:id
user_id: complete @user:id
Returns:
Returns:
defer.Deferred: (unicode) canonical_user_id, or None if zero or
defer.Deferred: (unicode) canonical_user_id, or None if zero or
...
@@ -551,7 +567,7 @@ class AuthHandler(BaseHandler):
...
@@ -551,7 +567,7 @@ class AuthHandler(BaseHandler):
return
None
return
None
@defer.inlineCallbacks
@defer.inlineCallbacks
def
_find_user_id_and_pwd_hash
(
self
,
user_id
):
def
_find_user_id_and_pwd_hash
(
self
,
user_id
:
str
):
"""
Checks to see if a user with the given id exists. Will check case
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
insensitively, but will return None if there are multiple inexact
matches.
matches.
...
@@ -581,7 +597,7 @@ class AuthHandler(BaseHandler):
...
@@ -581,7 +597,7 @@ class AuthHandler(BaseHandler):
)
)
return
result
return
result
def
get_supported_login_types
(
self
):
def
get_supported_login_types
(
self
)
->
Iterable
[
str
]
:
"""
Get a the login types supported for the /login API
"""
Get a the login types supported for the /login API
By default this is just
'
m.login.password
'
(unless password_enabled is
By default this is just
'
m.login.password
'
(unless password_enabled is
...
@@ -589,20 +605,20 @@ class AuthHandler(BaseHandler):
...
@@ -589,20 +605,20 @@ class AuthHandler(BaseHandler):
other login types.
other login types.
Returns:
Returns:
Iterable[str]:
login types
login types
"""
"""
return
self
.
_supported_login_types
return
self
.
_supported_login_types
@defer.inlineCallbacks
@defer.inlineCallbacks
def
validate_login
(
self
,
username
,
login_submission
):
def
validate_login
(
self
,
username
:
str
,
login_submission
:
Dict
[
str
,
Any
]
):
"""
Authenticates the user for the /login API
"""
Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
Also used by the user-interactive auth flow to validate
m.login.password auth types.
m.login.password auth types.
Args:
Args:
username
(str)
: username supplied by the user
username: username supplied by the user
login_submission
(dict)
: the whole of the login submission
login_submission: the whole of the login submission
(including
'
type
'
and other relevant fields)
(including
'
type
'
and other relevant fields)
Returns:
Returns:
Deferred[str, func]: canonical user id, and optional callback
Deferred[str, func]: canonical user id, and optional callback
...
@@ -690,13 +706,13 @@ class AuthHandler(BaseHandler):
...
@@ -690,13 +706,13 @@ class AuthHandler(BaseHandler):
raise
LoginError
(
403
,
"
Invalid password
"
,
errcode
=
Codes
.
FORBIDDEN
)
raise
LoginError
(
403
,
"
Invalid password
"
,
errcode
=
Codes
.
FORBIDDEN
)
@defer.inlineCallbacks
@defer.inlineCallbacks
def
check_password_provider_3pid
(
self
,
medium
,
address
,
password
):
def
check_password_provider_3pid
(
self
,
medium
:
str
,
address
:
str
,
password
:
str
):
"""
Check if a password provider is able to validate a thirdparty login
"""
Check if a password provider is able to validate a thirdparty login
Args:
Args:
medium
(str)
: The medium of the 3pid (ex. email).
medium: The medium of the 3pid (ex. email).
address
(str)
: The address of the 3pid (ex. jdoe@example.com).
address: The address of the 3pid (ex. jdoe@example.com).
password
(str)
: The password of the user.
password: The password of the user.
Returns:
Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id,
Deferred[(str|None, func|None)]: A tuple of `(user_id,
...
@@ -724,15 +740,15 @@ class AuthHandler(BaseHandler):
...
@@ -724,15 +740,15 @@ class AuthHandler(BaseHandler):
return
None
,
None
return
None
,
None
@defer.inlineCallbacks
@defer.inlineCallbacks
def
_check_local_password
(
self
,
user_id
,
password
):
def
_check_local_password
(
self
,
user_id
:
str
,
password
:
str
):
"""
Authenticate a user against the local password database.
"""
Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
user_id is checked case insensitively, but will return None if there are
multiple inexact matches.
multiple inexact matches.
Args:
Args:
user_id
(unicode)
: complete @user:id
user_id: complete @user:id
password
(unicode)
: the provided password
password: the provided password
Returns:
Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if
Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
unknown user/bad password
...
@@ -755,7 +771,7 @@ class AuthHandler(BaseHandler):
...
@@ -755,7 +771,7 @@ class AuthHandler(BaseHandler):
return
user_id
return
user_id
@defer.inlineCallbacks
@defer.inlineCallbacks
def
validate_short_term_login_token_and_get_user_id
(
self
,
login_token
):
def
validate_short_term_login_token_and_get_user_id
(
self
,
login_token
:
str
):
auth_api
=
self
.
hs
.
get_auth
()
auth_api
=
self
.
hs
.
get_auth
()
user_id
=
None
user_id
=
None
try
:
try
:
...
@@ -769,11 +785,11 @@ class AuthHandler(BaseHandler):
...
@@ -769,11 +785,11 @@ class AuthHandler(BaseHandler):
return
user_id
return
user_id
@defer.inlineCallbacks
@defer.inlineCallbacks
def
delete_access_token
(
self
,
access_token
):
def
delete_access_token
(
self
,
access_token
:
str
):
"""
Invalidate a single access token
"""
Invalidate a single access token
Args:
Args:
access_token
(str)
: access token to be deleted
access_token: access token to be deleted
Returns:
Returns:
Deferred
Deferred
...
@@ -798,15 +814,17 @@ class AuthHandler(BaseHandler):
...
@@ -798,15 +814,17 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
@defer.inlineCallbacks
def
delete_access_tokens_for_user
(
def
delete_access_tokens_for_user
(
self
,
user_id
,
except_token_id
=
None
,
device_id
=
None
self
,
user_id
:
str
,
except_token_id
:
Optional
[
str
]
=
None
,
device_id
:
Optional
[
str
]
=
None
,
):
):
"""
Invalidate access tokens belonging to a user
"""
Invalidate access tokens belonging to a user
Args:
Args:
user_id (str): ID of user the tokens belong to
user_id: ID of user the tokens belong to
except_token_id (str|None): access_token ID which should *not* be
except_token_id: access_token ID which should *not* be deleted
deleted
device_id: ID of device the tokens are associated with.
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
If None, tokens associated with any device (or no device) will
be deleted
be deleted
Returns:
Returns:
...
@@ -830,7 +848,7 @@ class AuthHandler(BaseHandler):
...
@@ -830,7 +848,7 @@ class AuthHandler(BaseHandler):
)
)
@defer.inlineCallbacks
@defer.inlineCallbacks
def
add_threepid
(
self
,
user_id
,
medium
,
address
,
validated_at
):
def
add_threepid
(
self
,
user_id
:
str
,
medium
:
str
,
address
:
str
,
validated_at
:
int
):
# check if medium has a valid value
# check if medium has a valid value
if
medium
not
in
[
"
email
"
,
"
msisdn
"
]:
if
medium
not
in
[
"
email
"
,
"
msisdn
"
]:
raise
SynapseError
(
raise
SynapseError
(
...
@@ -856,19 +874,20 @@ class AuthHandler(BaseHandler):
...
@@ -856,19 +874,20 @@ class AuthHandler(BaseHandler):
)
)
@defer.inlineCallbacks
@defer.inlineCallbacks
def
delete_threepid
(
self
,
user_id
,
medium
,
address
,
id_server
=
None
):
def
delete_threepid
(
self
,
user_id
:
str
,
medium
:
str
,
address
:
str
,
id_server
:
Optional
[
str
]
=
None
):
"""
Attempts to unbind the 3pid on the identity servers and deletes it
"""
Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
from the local database.
Args:
Args:
user_id
(str)
user_id
: ID of user to remove the 3pid from.
medium
(str)
medium
: The medium of the 3pid being removed:
"
email
"
or
"
msisdn
"
.
address
(str)
address
: The 3pid address to remove.
id_server
(str|None)
: Use the given identity server when unbinding
id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
identity server specified when binding (if known).
Returns:
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
Deferred[bool]: Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn
'
t support the
the identity server, False if identity server doesn
'
t support the
...
@@ -887,17 +906,18 @@ class AuthHandler(BaseHandler):
...
@@ -887,17 +906,18 @@ class AuthHandler(BaseHandler):
yield
self
.
store
.
user_delete_threepid
(
user_id
,
medium
,
address
)
yield
self
.
store
.
user_delete_threepid
(
user_id
,
medium
,
address
)
return
result
return
result
def
_save_session
(
self
,
session
):
def
_save_session
(
self
,
session
:
Dict
[
str
,
Any
])
->
None
:
"""
Update the last used time on the session to now and add it back to the session store.
"""
# TODO: Persistent storage
# TODO: Persistent storage
logger
.
debug
(
"
Saving session %s
"
,
session
)
logger
.
debug
(
"
Saving session %s
"
,
session
)
session
[
"
last_used
"
]
=
self
.
hs
.
get_clock
().
time_msec
()
session
[
"
last_used
"
]
=
self
.
hs
.
get_clock
().
time_msec
()
self
.
sessions
[
session
[
"
id
"
]]
=
session
self
.
sessions
[
session
[
"
id
"
]]
=
session
def
hash
(
self
,
password
):
def
hash
(
self
,
password
:
str
):
"""
Computes a secure hash of password.
"""
Computes a secure hash of password.
Args:
Args:
password
(unicode)
: Password to hash.
password: Password to hash.
Returns:
Returns:
Deferred(unicode): Hashed password.
Deferred(unicode): Hashed password.
...
@@ -914,12 +934,12 @@ class AuthHandler(BaseHandler):
...
@@ -914,12 +934,12 @@ class AuthHandler(BaseHandler):
return
defer_to_thread
(
self
.
hs
.
get_reactor
(),
_do_hash
)
return
defer_to_thread
(
self
.
hs
.
get_reactor
(),
_do_hash
)
def
validate_hash
(
self
,
password
,
stored_hash
):
def
validate_hash
(
self
,
password
:
str
,
stored_hash
:
bytes
):
"""
Validates that self.hash(password) == stored_hash.
"""
Validates that self.hash(password) == stored_hash.
Args:
Args:
password
(unicode)
: Password to hash.
password: Password to hash.
stored_hash
(bytes)
: Expected hash value.
stored_hash: Expected hash value.
Returns:
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
Deferred(bool): Whether self.hash(password) == stored_hash.
...
@@ -1007,7 +1027,9 @@ class MacaroonGenerator(object):
...
@@ -1007,7 +1027,9 @@ class MacaroonGenerator(object):
hs
=
attr
.
ib
()
hs
=
attr
.
ib
()
def
generate_access_token
(
self
,
user_id
,
extra_caveats
=
None
):
def
generate_access_token
(
self
,
user_id
:
str
,
extra_caveats
:
Optional
[
List
[
str
]]
=
None
)
->
str
:
extra_caveats
=
extra_caveats
or
[]
extra_caveats
=
extra_caveats
or
[]
macaroon
=
self
.
_generate_base_macaroon
(
user_id
)
macaroon
=
self
.
_generate_base_macaroon
(
user_id
)
macaroon
.
add_first_party_caveat
(
"
type = access
"
)
macaroon
.
add_first_party_caveat
(
"
type = access
"
)
...
@@ -1020,16 +1042,9 @@ class MacaroonGenerator(object):
...
@@ -1020,16 +1042,9 @@ class MacaroonGenerator(object):
macaroon
.
add_first_party_caveat
(
caveat
)
macaroon
.
add_first_party_caveat
(
caveat
)
return
macaroon
.
serialize
()
return
macaroon
.
serialize
()
def
generate_short_term_login_token
(
self
,
user_id
,
duration_in_ms
=
(
2
*
60
*
1000
)):
def
generate_short_term_login_token
(
"""
self
,
user_id
:
str
,
duration_in_ms
:
int
=
(
2
*
60
*
1000
)
)
->
str
:
Args:
user_id (unicode):
duration_in_ms (int):
Returns:
unicode
"""
macaroon
=
self
.
_generate_base_macaroon
(
user_id
)
macaroon
=
self
.
_generate_base_macaroon
(
user_id
)
macaroon
.
add_first_party_caveat
(
"
type = login
"
)
macaroon
.
add_first_party_caveat
(
"
type = login
"
)
now
=
self
.
hs
.
get_clock
().
time_msec
()
now
=
self
.
hs
.
get_clock
().
time_msec
()
...
@@ -1037,12 +1052,12 @@ class MacaroonGenerator(object):
...
@@ -1037,12 +1052,12 @@ class MacaroonGenerator(object):
macaroon
.
add_first_party_caveat
(
"
time < %d
"
%
(
expiry
,))
macaroon
.
add_first_party_caveat
(
"
time < %d
"
%
(
expiry
,))
return
macaroon
.
serialize
()
return
macaroon
.
serialize
()
def
generate_delete_pusher_token
(
self
,
user_id
)
:
def
generate_delete_pusher_token
(
self
,
user_id
:
str
)
->
str
:
macaroon
=
self
.
_generate_base_macaroon
(
user_id
)
macaroon
=
self
.
_generate_base_macaroon
(
user_id
)
macaroon
.
add_first_party_caveat
(
"
type = delete_pusher
"
)
macaroon
.
add_first_party_caveat
(
"
type = delete_pusher
"
)
return
macaroon
.
serialize
()
return
macaroon
.
serialize
()
def
_generate_base_macaroon
(
self
,
user_id
)
:
def
_generate_base_macaroon
(
self
,
user_id
:
str
)
->
pymacaroons
.
Macaroon
:
macaroon
=
pymacaroons
.
Macaroon
(
macaroon
=
pymacaroons
.
Macaroon
(
location
=
self
.
hs
.
config
.
server_name
,
location
=
self
.
hs
.
config
.
server_name
,
identifier
=
"
key
"
,
identifier
=
"
key
"
,
...
...
This diff is collapsed.
Click to expand it.
tox.ini
+
1
−
0
View file @
77d0a450
...
@@ -185,6 +185,7 @@ commands = mypy \
...
@@ -185,6 +185,7 @@ commands = mypy \
synapse/federation/federation_client.py
\
synapse/federation/federation_client.py
\
synapse/federation/sender
\
synapse/federation/sender
\
synapse/federation/transport
\
synapse/federation/transport
\
synapse/handlers/auth.py
\
synapse/handlers/directory.py
\
synapse/handlers/directory.py
\
synapse/handlers/presence.py
\
synapse/handlers/presence.py
\
synapse/handlers/sync.py
\
synapse/handlers/sync.py
\
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment