97 lines
4.2 KiB
Python
97 lines
4.2 KiB
Python
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||
|
|
||
|
import logging
|
||
|
from dataclasses import dataclass, field
|
||
|
from pathlib import Path
|
||
|
from typing import Optional, Union
|
||
|
|
||
|
from .generation.configuration_utils import GenerationConfig
|
||
|
from .training_args import TrainingArguments
|
||
|
from .utils import add_start_docstrings
|
||
|
|
||
|
|
||
|
# Module-level logger named after this module; not used by the definitions currently in this file.
logger = logging.getLogger(name=__name__)
|
||
|
|
||
|
|
||
|
@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class Seq2SeqTrainingArguments(TrainingArguments):
    """
    Args:
        sortish_sampler (`bool`, *optional*, defaults to `False`):
            Whether to use a *sortish sampler* or not. Only possible if the underlying datasets are *Seq2SeqDataset*
            for now but will become generally available in the near future.

            It sorts the inputs according to lengths in order to minimize the padding size, with a bit of randomness
            for the training set.
        predict_with_generate (`bool`, *optional*, defaults to `False`):
            Whether to use generate to calculate generative metrics (ROUGE, BLEU).
        generation_max_length (`int`, *optional*):
            The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default to the
            `max_length` value of the model configuration.
        generation_num_beams (`int`, *optional*):
            The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default to the
            `num_beams` value of the model configuration.
        generation_config (`str` or `Path` or [`~generation.GenerationConfig`], *optional*):
            Allows loading a [`~generation.GenerationConfig`] from the `from_pretrained` method. This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
            - a [`~generation.GenerationConfig`] object.
    """

    sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    # `None` means "fall back to the model configuration's value" (see the class docstring).
    generation_max_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
                "to the `max_length` value of the model configuration."
            )
        },
    )
    # `None` means "fall back to the model configuration's value" (see the class docstring).
    generation_num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
                "to the `num_beams` value of the model configuration."
            )
        },
    )
    # May hold a model id / path (`str` or `Path`) or an already-built `GenerationConfig`;
    # `to_dict` normalizes the non-JSON-native forms of this.
    generation_config: Optional[Union[str, Path, GenerationConfig]] = field(
        default=None,
        metadata={
            "help": "Model id, file path or url pointing to a GenerationConfig json file, to use during prediction."
        },
    )

    def to_dict(self):
        """
        Serializes this instance to a dictionary, for JSON serialization support.

        The base dictionary comes from `TrainingArguments.to_dict` (which, per this method's original documentation,
        replaces `Enum` members by their values and obfuscates token values). On top of that:

        - any [`~generation.GenerationConfig`] value is replaced by its own `to_dict()` output;
        - any `pathlib.Path` value (e.g. a `generation_config` given as a path) is replaced by its `str` form.

        Neither `GenerationConfig` nor `Path` can be encoded by `json.dumps`, so leaving them in place would break
        JSON serialization of the returned dictionary.

        Returns:
            `dict`: the serialized arguments.
        """
        # Base-class serialization; field selection is the base class's responsibility.
        d = super().to_dict()
        # Only existing keys are reassigned (the dict never changes size), so updating while iterating is safe.
        for k, v in d.items():
            if isinstance(v, GenerationConfig):
                d[k] = v.to_dict()
            elif isinstance(v, Path):
                # `generation_config` is typed to accept `Path`; store a JSON-encodable string instead.
                d[k] = str(v)
        return d
|