Skip to content
Snippets Groups Projects
Commit c9d4a566 authored by Tulir Asokan's avatar Tulir Asokan :cat2:
Browse files

Add field path to ThriftReader errors

parent 77c0dc43
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,7 @@ _alpha_length = ord("z") - ord("a") + 1
class ThriftReader(io.BytesIO):
"""
ThriftReader implements decodiong the Thrift Compact protocol into Python values.
ThriftReader implements decoding the Thrift Compact protocol into Python values.
https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
"""
......@@ -201,19 +201,20 @@ class ThriftReader(io.BytesIO):
else:
self.read_val(type)
def read_val_recursive(self, rtype: RecursiveType) -> Any:
def read_val_recursive(self, rtype: RecursiveType, field_path: str = "root") -> Any:
"""
Read any type of value from the buffer.
Args:
rtype: The exact type specification for the value to read.
field_path: The recursive field path used for debugging.
Returns:
The parsed value.
"""
if rtype.type == TType.STRUCT:
self._push_stack()
val = self.read_struct(rtype.python_type)
val = self.read_struct(rtype.python_type, field_path=field_path)
self._pop_stack()
return val
elif rtype.type == TType.MAP:
......@@ -221,35 +222,41 @@ class ThriftReader(io.BytesIO):
if length == 0:
return {}
elif key_type != rtype.key_type.type:
raise ValueError(f"Unexpected key type: expected {rtype.key_type.type.name}, "
f"got {key_type.name}")
raise ValueError(f"Unexpected key type at {field_path}: "
f"expected {rtype.key_type.type.name}, got {key_type.name}")
elif value_type != rtype.value_type.type:
raise ValueError(f"Unexpected value type: expected {rtype.value_type.type.name}, "
f"got {value_type.name}")
raise ValueError(f"Unexpected value type at {field_path}: "
f"expected {rtype.value_type.type.name}, got {value_type.name}")
return {
self.read_val_recursive(rtype.key_type): self.read_val_recursive(rtype.value_type)
for _ in range(length)
self.read_val_recursive(rtype.key_type, field_path=f"{field_path}[{i}.key]"):
self.read_val_recursive(rtype.value_type, field_path=f"{field_path}[{i}.val]")
for i in range(length)
}
elif rtype.type in (TType.LIST, TType.SET):
item_type, length = self.read_list_header()
if item_type != rtype.item_type.type:
raise ValueError(f"Unexpected item type: expected {rtype.item_type.type.name}, "
f"got {item_type.name}")
data = (self.read_val_recursive(rtype.item_type) for _ in range(length))
raise ValueError(f"Unexpected item type at {field_path}: "
f"expected {rtype.item_type.type.name}, got {item_type.name}")
data = (self.read_val_recursive(rtype.item_type, field_path=f"{field_path}[{i}]")
for i in range(length))
return set(data) if rtype.type == TType.SET else list(data)
else:
if rtype.type == TType.BINARY and rtype.python_type != bytes:
# For non-bytes python types, decode as UTF-8 and then call the
# type constructor in case it's an enum or something like that.
return rtype.python_type(self.read_val(rtype.type).decode("utf-8"))
try:
return rtype.python_type(self.read_val(rtype.type).decode("utf-8"))
except UnicodeDecodeError as e:
raise ValueError(f"Failed to decode string at {field_path}: {e}")
return self.read_val(rtype.type)
def read_struct(self, type: Type[T]) -> T:
def read_struct(self, type: Type[T], field_path: str = "root") -> T:
"""
Assuming the data in the buffer is a Thrift struct, parse it into a dataclass.
Args:
type: The Python type to parse the struct into.
field_path: The recursive field path used for debugging.
Returns:
An instance of the given type with the parsed data.
......@@ -272,9 +279,12 @@ class ThriftReader(io.BytesIO):
if expected_type == TType.BOOL:
args[field_meta.name] = True if field_type == TType.TRUE else False
else:
args[field_meta.name] = self.read_val_recursive(field_meta.rtype)
# print("Creating a", type.__name__, "with", args)
return type(**args)
fp = f"{field_path}.{field_meta.name}"
args[field_meta.name] = self.read_val_recursive(field_meta.rtype, field_path=fp)
try:
return type(**args)
except TypeError as e:
raise ValueError(f"Failed to create {type.__name__} at {field_path}") from e
def pretty_print(self, field_type: TType = TType.STRUCT, _indent: str = "", _prefix: str = ""
) -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment