# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Contains FieldMask class."""

from google.protobuf.descriptor import FieldDescriptor


class FieldMask(object):
    """Class for FieldMask message type.

    The methods here read and write ``self.paths`` and call ``self.Clear()``;
    both are expected to be supplied by the message class this is mixed into.
    """

    __slots__ = ()

    def ToJsonString(self):
        """Converts FieldMask to string according to proto3 JSON spec.

        Returns:
          The paths of this mask, each converted from snake_case to
          camelCase, joined with ','.

        Raises:
          ValueError: If a path cannot be converted to camelCase.
        """
        return ','.join(_SnakeCaseToCamelCase(path) for path in self.paths)

    def FromJsonString(self, value):
        """Converts string to FieldMask according to proto3 JSON spec.

        Args:
          value: A string of comma-separated camelCase paths. An empty
            string yields an empty mask.

        Raises:
          ValueError: If value is not a str, or a path contains "_".
        """
        if not isinstance(value, str):
            raise ValueError(
                'FieldMask JSON value not a string: {!r}'.format(value))
        self.Clear()
        if value:
            for path in value.split(','):
                self.paths.append(_CamelCaseToSnakeCase(path))

    def IsValidForDescriptor(self, message_descriptor):
        """Checks whether the FieldMask is valid for Message Descriptor.

        Args:
          message_descriptor: Descriptor of the message type to check the
            paths against.

        Returns:
          True if every path resolves in message_descriptor (trivially True
          for an empty mask), False otherwise.
        """
        return all(
            _IsValidPath(message_descriptor, path) for path in self.paths)

    def AllFieldsFromDescriptor(self, message_descriptor):
        """Gets all direct fields of Message Descriptor to FieldMask.

        Args:
          message_descriptor: Descriptor whose direct field names replace
            the current contents of this mask.
        """
        self.Clear()
        for field in message_descriptor.fields:
            self.paths.append(field.name)

    def CanonicalFormFromMask(self, mask):
        """Converts a FieldMask to the canonical form.

        Removes paths that are covered by another path. For example,
        "foo.bar" is covered by "foo" and will be removed if "foo"
        is also in the FieldMask. Then sorts all paths in alphabetical order.

        Args:
          mask: The original FieldMask to be converted.
        """
        tree = _FieldMaskTree(mask)
        tree.ToFieldMask(self)

    def Union(self, mask1, mask2):
        """Merges mask1 and mask2 into this FieldMask.

        Args:
          mask1: The first FieldMask.
          mask2: The second FieldMask.

        Raises:
          ValueError: If mask1 or mask2 is not a FieldMask message.
        """
        _CheckFieldMaskMessage(mask1)
        _CheckFieldMaskMessage(mask2)
        tree = _FieldMaskTree(mask1)
        tree.MergeFromFieldMask(mask2)
        tree.ToFieldMask(self)

    def Intersect(self, mask1, mask2):
        """Intersects mask1 and mask2 into this FieldMask.

        Args:
          mask1: The first FieldMask.
          mask2: The second FieldMask.

        Raises:
          ValueError: If mask1 or mask2 is not a FieldMask message.
        """
        _CheckFieldMaskMessage(mask1)
        _CheckFieldMaskMessage(mask2)
        tree = _FieldMaskTree(mask1)
        intersection = _FieldMaskTree()
        for path in mask2.paths:
            tree.IntersectPath(path, intersection)
        intersection.ToFieldMask(self)

    def MergeMessage(
            self, source, destination,
            replace_message_field=False, replace_repeated_field=False):
        """Merges fields specified in FieldMask from source to destination.

        Args:
          source: Source message.
          destination: The destination message to be merged into.
          replace_message_field: Replace message field if True. Merge message
            field if False.
          replace_repeated_field: Replace repeated field if True. Append
            elements of repeated field if False.

        Raises:
          ValueError: If a path names a field that does not exist in the
            source message, or has sub-fields under a field that is not a
            singular message field.
        """
        tree = _FieldMaskTree(self)
        tree.MergeMessage(
            source, destination, replace_message_field,
            replace_repeated_field)


def _IsValidPath(message_descriptor, path):
    """Checks whether the path is valid for Message Descriptor.

    Every component except the last must name a singular message field;
    the last component must name any field of the message reached so far.
    """
    parts = path.split('.')
    last = parts.pop()
    for name in parts:
        field = message_descriptor.fields_by_name.get(name)
        if (field is None or
                field.label == FieldDescriptor.LABEL_REPEATED or
                field.type != FieldDescriptor.TYPE_MESSAGE):
            return False
        message_descriptor = field.message_type
    return last in message_descriptor.fields_by_name


def _CheckFieldMaskMessage(message):
    """Raises ValueError if message is not a FieldMask."""
    message_descriptor = message.DESCRIPTOR
    if (message_descriptor.name != 'FieldMask' or
            message_descriptor.file.name !=
            'google/protobuf/field_mask.proto'):
        raise ValueError('Message {0} is not a FieldMask.'.format(
            message_descriptor.full_name))


def _SnakeCaseToCamelCase(path_name):
    """Converts a path name from snake_case to camelCase.

    Raises:
      ValueError: If path_name contains an uppercase letter, a "_" that is
        not followed by a lowercase letter, or a trailing "_".
    """
    result = []
    after_underscore = False
    for c in path_name:
        if c.isupper():
            raise ValueError(
                'Fail to print FieldMask to Json string: Path name '
                '{0} must not contain uppercase letters.'.format(path_name))
        if after_underscore:
            if c.islower():
                result.append(c.upper())
                after_underscore = False
            else:
                raise ValueError(
                    'Fail to print FieldMask to Json string: The '
                    'character after a "_" must be a lowercase letter '
                    'in path name {0}.'.format(path_name))
        elif c == '_':
            # The underscore itself is dropped; it only flags the next char.
            after_underscore = True
        else:
            result.append(c)
    if after_underscore:
        raise ValueError(
            'Fail to print FieldMask to Json string: Trailing "_" '
            'in path name {0}.'.format(path_name))
    return ''.join(result)


def _CamelCaseToSnakeCase(path_name):
    """Converts a field name from camelCase to snake_case.

    Raises:
      ValueError: If path_name contains "_".
    """
    result = []
    for c in path_name:
        if c == '_':
            raise ValueError('Fail to parse FieldMask: Path name '
                             '{0} must not contain "_"s.'.format(path_name))
        if c.isupper():
            result.append('_')
            result.append(c.lower())
        else:
            result.append(c)
    return ''.join(result)


class _FieldMaskTree(object):
    """Represents a FieldMask in a tree structure.

    For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
    the FieldMaskTree will be:
        [_root] -+- foo -+- bar
                 |       |
                 |       +- baz
                 |
                 +- bar --- baz
    In the tree, each leaf node represents a field path.

    Nodes are plain dicts mapping a path component to its child node; an
    empty dict is a leaf.
    """

    __slots__ = ('_root',)

    def __init__(self, field_mask=None):
        """Initializes the tree by FieldMask."""
        self._root = {}
        if field_mask:
            self.MergeFromFieldMask(field_mask)

    def MergeFromFieldMask(self, field_mask):
        """Merges a FieldMask to the tree."""
        for path in field_mask.paths:
            self.AddPath(path)

    def AddPath(self, path):
        """Adds a field path into the tree.

        If the field path to add is a sub-path of an existing field path
        in the tree (i.e., a leaf node), it means the tree already matches
        the given path so nothing will be added to the tree. If the path
        matches an existing non-leaf node in the tree, that non-leaf node
        will be turned into a leaf node with all its children removed because
        the path matches all the node's children. Otherwise, a new path will
        be added.

        Args:
          path: The field path to add.
        """
        node = self._root
        for name in path.split('.'):
            if name not in node:
                node[name] = {}
            elif not node[name]:
                # Pre-existing empty node implies we already have this entire
                # tree.
                return
            node = node[name]
        # Remove any sub-trees we might have had.
        node.clear()

    def ToFieldMask(self, field_mask):
        """Converts the tree to a FieldMask.

        Args:
          field_mask: The FieldMask to overwrite; it is cleared first and
            then filled with the tree's leaf paths in sorted order.
        """
        field_mask.Clear()
        _AddFieldPaths(self._root, '', field_mask)

    def IntersectPath(self, path, intersection):
        """Calculates the intersection part of a field path with this tree.

        Args:
          path: The field path to intersect with this tree.
          intersection: The out tree to record the intersection part.
        """
        node = self._root
        for name in path.split('.'):
            if name not in node:
                # The path leaves the tree: nothing in common.
                return
            elif not node[name]:
                # A leaf of this tree covers the whole path.
                intersection.AddPath(path)
                return
            node = node[name]
        # The path ends on an inner node: every leaf below it is shared.
        intersection.AddLeafNodes(path, node)

    def AddLeafNodes(self, prefix, node):
        """Adds leaf nodes beginning with prefix to this tree.

        Args:
          prefix: The dotted path leading to node.
          node: The sub-tree whose leaves are added under prefix.
        """
        if not node:
            self.AddPath(prefix)
        for name in node:
            child_path = prefix + '.' + name
            self.AddLeafNodes(child_path, node[name])

    def MergeMessage(
            self, source, destination,
            replace_message, replace_repeated):
        """Merge all fields specified by this tree from source to destination.

        Args:
          source: Source message.
          destination: The destination message to be merged into.
          replace_message: Replace message field if True. Merge message
            field if False.
          replace_repeated: Replace repeated field if True. Append elements
            of repeated field if False.
        """
        _MergeMessage(
            self._root, source, destination, replace_message,
            replace_repeated)


def _StrConvert(value):
    """Converts value to str if it is not."""
    # This file is imported by c extension and some methods like ClearField
    # requires string for the field name. py2/py3 has different text
    # type and may use unicode.
    if not isinstance(value, str):
        return value.encode('utf-8')
    return value


def _MergeMessage(
        node, source, destination, replace_message, replace_repeated):
    """Merge all fields specified by a sub-tree from source to destination.

    Args:
      node: A dict sub-tree of a _FieldMaskTree; keys are field names of
        source and an empty child marks a field to be merged as a whole.
      source: Source message.
      destination: The destination message to be merged into.
      replace_message: Replace message field if True. Merge message field
        if False.
      replace_repeated: Replace repeated field if True. Append elements of
        repeated field if False.

    Raises:
      ValueError: If a name in node is not a field of source, or a name
        with sub-paths is not a singular message field.
    """
    source_descriptor = source.DESCRIPTOR
    for name in node:
        child = node[name]
        # Use .get() so that an unknown field reaches the ValueError below
        # instead of escaping as a KeyError from the mapping lookup.
        field = source_descriptor.fields_by_name.get(name)
        if field is None:
            raise ValueError(
                'Error: Can\'t find field {0} in message {1}.'.format(
                    name, source_descriptor.full_name))
        if child:
            # Sub-paths are only allowed for singular message fields.
            if (field.label == FieldDescriptor.LABEL_REPEATED or
                    field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
                raise ValueError(
                    'Error: Field {0} in message {1} is not a singular '
                    'message field and cannot have sub-fields.'.format(
                        name, source_descriptor.full_name))
            if source.HasField(name):
                _MergeMessage(
                    child, getattr(source, name), getattr(destination, name),
                    replace_message, replace_repeated)
            continue
        if field.label == FieldDescriptor.LABEL_REPEATED:
            if replace_repeated:
                destination.ClearField(_StrConvert(name))
            repeated_source = getattr(source, name)
            repeated_destination = getattr(destination, name)
            repeated_destination.MergeFrom(repeated_source)
        else:
            if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
                if replace_message:
                    destination.ClearField(_StrConvert(name))
                if source.HasField(name):
                    getattr(destination, name).MergeFrom(
                        getattr(source, name))
            else:
                setattr(destination, name, getattr(source, name))


def _AddFieldPaths(node, prefix, field_mask):
    """Adds the field paths descended from node to field_mask.

    Args:
      node: A dict sub-tree of a _FieldMaskTree.
      prefix: The dotted path leading to node; '' for the root.
      field_mask: The FieldMask whose paths list receives the leaf paths.
    """
    # A leaf with a non-empty prefix is a complete path. The empty root has
    # no prefix and falls through to a loop over nothing.
    if not node and prefix:
        field_mask.paths.append(prefix)
        return
    for name in sorted(node):
        if prefix:
            child_path = prefix + '.' + name
        else:
            child_path = name
        _AddFieldPaths(node[name], child_path, field_mask)