
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Feb 11, 2025
1 parent 4bb3b5a commit d619d92
Showing 89 changed files with 30,309 additions and 29,780 deletions.
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/_app.py
@@ -298,7 +298,7 @@ def _show_telemetry_warning():
" https://docs.argilla.io/latest/reference/argilla-server/telemetry/\n\n"
"Telemetry is currently enabled. If you want to disable it, you can configure\n"
"the environment variable before relaunching the server:\n\n"
f'{"#set HF_HUB_DISABLE_TELEMETRY=1" if os.name == "nt" else "$>export HF_HUB_DISABLE_TELEMETRY=1"}'
f"{'#set HF_HUB_DISABLE_TELEMETRY=1' if os.name == 'nt' else '$>export HF_HUB_DISABLE_TELEMETRY=1'}"
)
_LOGGER.warning(message)

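The hunk above is the quote normalization the formatting hook applies (most likely ruff-format or a similar formatter; the hook configuration itself is not part of this diff): the f-string is rewritten with the preferred double quotes, and the literals inside the replacement field switch to single quotes so the two quote styles do not clash. A minimal runnable sketch of that rule, simplified from the line above:

import os

# Before formatting: single-quoted f-string so the double-quoted inner literals need no escaping.
before = f'{"set" if os.name == "nt" else "export"} HF_HUB_DISABLE_TELEMETRY=1'

# After formatting: double-quoted f-string with single-quoted inner literals -- same runtime value.
after = f"{'set' if os.name == 'nt' else 'export'} HF_HUB_DISABLE_TELEMETRY=1"

assert before == after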
2 changes: 1 addition & 1 deletion argilla-v1/src/argilla_v1/client/datasets.py
@@ -1156,7 +1156,7 @@ def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[
raise ValueError(
"The following annotation does not align with the tokens"
" produced by the provided spacy language model:"
f" {(anno[0], record.text[anno[1]:anno[2]])}, {list(doc)}"
f" {(anno[0], record.text[anno[1] : anno[2]])}, {list(doc)}"
)
else:
entities.append(span)
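This edit follows the PEP 8 slice rule that formatters enforce: when a slice bound is anything more complex than a plain name or number, the colon is treated like a binary operator and gets a space on each side. A small self-contained sketch with made-up data:

text = "hello world"
anno = ("greeting", 0, 5)

simple = text[0:5]  # simple bounds keep the compact form
derived = text[anno[1] : anno[2]]  # complex bounds gain spaces around the colon

assert simple == derived == "hello"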
4 changes: 2 additions & 2 deletions argilla-v1/src/argilla_v1/client/feedback/dataset/helpers.py
@@ -169,7 +169,7 @@ def normalize_records(
new_records.append(record)
else:
raise ValueError(
"Expected `records` to be a list of `dict` or `FeedbackRecord`," f" got type `{type(record)}` instead."
f"Expected `records` to be a list of `dict` or `FeedbackRecord`, got type `{type(record)}` instead."
)
return new_records

@@ -384,7 +384,7 @@ def _validate_record_metadata(record: FeedbackRecord, metadata_schema: typing.Ty
metadata_schema.parse_obj(record.metadata)
except ValidationError as e:
raise ValueError(
f"`FeedbackRecord.metadata` {record.metadata} does not match the expected schema," f" with exception: {e}"
f"`FeedbackRecord.metadata` {record.metadata} does not match the expected schema, with exception: {e}"
) from e


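Several hunks in this commit, including the two above, merge implicit string concatenation (a plain literal placed next to an f-string) into a single f-string once the combined message fits on one line. A minimal sketch of the before/after, using a made-up record value:

record = {"id": 1}

# Before: two adjacent literals, implicitly concatenated at compile time.
before = "Expected `records` to be a list of `dict` or `FeedbackRecord`," f" got type `{type(record)}` instead."

# After: one f-string carrying the whole message; the runtime value is identical.
after = f"Expected `records` to be a list of `dict` or `FeedbackRecord`, got type `{type(record)}` instead."

assert before == after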
(another changed file)
@@ -230,7 +230,7 @@ def __getitem__(self, key: Union[slice, int]) -> Union["FeedbackRecord", List["F
"""
if len(self._records) < 1:
raise RuntimeError(
"In order to get items from `FeedbackDataset` you need to add them first" " with `add_records`."
"In order to get items from `FeedbackDataset` you need to add them first with `add_records`."
)
if isinstance(key, int) and len(self._records) < key:
raise IndexError(f"This dataset contains {len(self)} records, so index {key} is out of range.")
@@ -331,8 +331,7 @@ def delete_vectors_settings(

if not self.vectors_settings:
raise ValueError(
"The current `FeedbackDataset` does not contain any `vectors_settings` defined, so"
" none can be deleted."
"The current `FeedbackDataset` does not contain any `vectors_settings` defined, so none can be deleted."
)

if not all(vector_setting in self._vectors_settings.keys() for vector_setting in vectors_settings):
(another changed file)
@@ -89,7 +89,7 @@ def __delete_dataset(client: "httpx.Client", id: UUID) -> None:
datasets_api_v1.delete_dataset(client=client, id=id)
except Exception as e:
raise Exception(
f"Failed while deleting the `FeedbackDataset` with ID '{id}' from Argilla with" f" exception: {e}"
f"Failed while deleting the `FeedbackDataset` with ID '{id}' from Argilla with exception: {e}"
) from e

@staticmethod
(another changed file)
@@ -422,7 +422,7 @@ def generate(model_id: str, instruction: str, context: str = "") -> str:
)
return tokenizer.decode(outputs[0])
generate("{self.output_dir.replace('"', '')}", "Is a toad a frog?")"""
generate("{self.output_dir.replace('"', "")}", "Is a toad a frog?")"""
)
elif self.task_type == "for_reward_modeling":
return predict_call + dedent(
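In the hunk above only the empty string is normalized to double quotes; the first argument to replace stays single-quoted because switching it would force an escaped double quote. A hedged sketch of that preference with a made-up directory name:

output_dir = 'runs/"v1"'

before = output_dir.replace('"', '')
after = output_dir.replace('"', "")  # only the literal that needs no escaping is switched

assert before == after == "runs/v1"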
2 changes: 1 addition & 1 deletion argilla-v1/src/argilla_v1/client/models.py
@@ -424,7 +424,7 @@ def __init__(
raise AssertionError("Missing fields: At least one of `text` or `tokens` argument must be provided!")

if (data.get("annotation") or data.get("prediction")) and text is None:
raise AssertionError("Missing field `text`: " "char level spans must be provided with a raw text sentence")
raise AssertionError("Missing field `text`: char level spans must be provided with a raw text sentence")

if text is None:
text = " ".join(tokens)
4 changes: 2 additions & 2 deletions argilla-v1/src/argilla_v1/client/sdk/commons/errors.py
@@ -26,7 +26,7 @@ def __init__(self, message: str, response: Any):
self.response = response

def __str__(self):
return f"\nUnexpected response: {self.message}" "\nResponse content:" f"\n{self.response}"
return f"\nUnexpected response: {self.message}\nResponse content:\n{self.response}"


class InputValueError(BaseClientError):
@@ -52,7 +52,7 @@ def __init__(self, **ctx):
self.ctx = ctx

def __str__(self):
return f"Argilla server returned an error with http status: {self.HTTP_STATUS}. " f"Error details: {self.ctx!r}"
return f"Argilla server returned an error with http status: {self.HTTP_STATUS}. Error details: {self.ctx!r}"


class BadRequestApiError(ArApiResponseError):
6 changes: 2 additions & 4 deletions argilla-v1/src/argilla_v1/client/workspaces.py
@@ -120,8 +120,7 @@ def users(self) -> List["UserModel"]:

def __repr__(self) -> str:
return (
f"Workspace(id={self.id}, name={self.name},"
f" inserted_at={self.inserted_at}, updated_at={self.updated_at})"
f"Workspace(id={self.id}, name={self.name}, inserted_at={self.inserted_at}, updated_at={self.updated_at})"
)

@allowed_for_roles(roles=[UserRole.owner])
@@ -330,8 +329,7 @@ def from_id(cls, id: UUID) -> "Workspace":
) from e
except ValidationApiError as e:
raise ValueError(
"The ID you provided is not a valid UUID, so please make sure that the"
" ID you provided is a valid one."
"The ID you provided is not a valid UUID, so please make sure that the ID you provided is a valid one."
) from e
except BaseClientError as e:
raise RuntimeError(f"Error while retrieving workspace with id=`{id}` from Argilla.") from e
(another changed file)
@@ -240,8 +240,7 @@ def _make_single_label_records(
pred_for_rec = [(self._weak_labels.labels[idx], prob[idx]) for idx in np.argsort(prob)[::-1]]
else:
raise NotImplementedError(
f"The tie break policy '{tie_break_policy.value}' is not"
f" implemented for {self.__class__.__name__}!"
f"The tie break policy '{tie_break_policy.value}' is not implemented for {self.__class__.__name__}!"
)

records_with_prediction.append(rec.copy(deep=True))
2 changes: 1 addition & 1 deletion argilla-v1/src/argilla_v1/training/autotrain_advanced.py
@@ -211,7 +211,7 @@ def __repr__(self):
formatted_string.append(arg_dict_key)
for idx, item in enumerate(arg_dict_single):
for key, val in item.items():
formatted_string.append(f"\tjob{idx+1}-{key}: {val}")
formatted_string.append(f"\tjob{idx + 1}-{key}: {val}")
return "\n".join(formatted_string)

def train(self, output_dir: str):
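The idx+1 to idx + 1 change here (and the same edit in gen_popular_issues.py further down) comes from the formatter now formatting the expressions inside f-string replacement fields, so binary operators get the usual surrounding spaces. The rendered text does not change, as this small sketch shows:

for idx, key in enumerate(["learning_rate", "epochs"]):
    before = f"\tjob{idx+1}-{key}"
    after = f"\tjob{idx + 1}-{key}"
    assert before == after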
6 changes: 3 additions & 3 deletions argilla-v1/tests/integration/client/test_models.py
@@ -130,21 +130,21 @@ def test_token_classification_with_tokens_and_tags(tokens, tags, annotation):
def test_token_classification_validations():
with pytest.raises(
AssertionError,
match=("Missing fields: " "At least one of `text` or `tokens` argument must be provided!"),
match=("Missing fields: At least one of `text` or `tokens` argument must be provided!"),
):
TokenClassificationRecord()

tokens = ["test", "text"]
annotation = [("test", 0, 4)]
with pytest.raises(
AssertionError,
match=("Missing field `text`: " "char level spans must be provided with a raw text sentence"),
match=("Missing field `text`: char level spans must be provided with a raw text sentence"),
):
TokenClassificationRecord(tokens=tokens, annotation=annotation)

with pytest.raises(
AssertionError,
match=("Missing field `text`: " "char level spans must be provided with a raw text sentence"),
match=("Missing field `text`: char level spans must be provided with a raw text sentence"),
):
TokenClassificationRecord(tokens=tokens, prediction=annotation)

2 changes: 1 addition & 1 deletion argilla-v1/tests/unit/client/sdk/models/conftest.py
@@ -45,7 +45,7 @@ def check_schema_props(client_props: dict, server_props: dict) -> bool:
continue
if name not in server_props:
LOGGER.warning(
f"Client property {name} not found in server properties. " "Make sure your API compatibility"
f"Client property {name} not found in server properties. Make sure your API compatibility"
)
different_props.append(name)
continue
(another changed file)
@@ -202,8 +202,7 @@
" \".svg\",\n",
" \".ico\",\n",
" \".json\",\n",
" \".ipynb\", # Erase this line if you want to include notebooks\n",
"\n",
" \".ipynb\", # Erase this line if you want to include notebooks\n",
" ],\n",
" GithubRepositoryReader.FilterType.EXCLUDE,\n",
" ),\n",
@@ -231,9 +230,7 @@
"outputs": [],
"source": [
"# LLM settings\n",
"Settings.llm = OpenAI(\n",
" model=\"gpt-3.5-turbo\", temperature=0.8, openai_api_key=openai_api_key\n",
")\n",
"Settings.llm = OpenAI(model=\"gpt-3.5-turbo\", temperature=0.8, openai_api_key=openai_api_key)\n",
"\n",
"# Load the data and create the index\n",
"index = VectorStoreIndex.from_documents(documents)\n",
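The notebook hunk above shows the line-length side of the same formatter: a call that had been split across several lines is collapsed back onto one line when it fits within the limit and carries no magic trailing comma. A stand-in sketch (dict replaces the real OpenAI constructor, and the key value is fake):

# Before: exploded call, no trailing comma after the last argument.
settings_before = dict(
    model="gpt-3.5-turbo", temperature=0.8, openai_api_key="not-a-real-key"
)

# After: the same call collapsed onto one line.
settings_after = dict(model="gpt-3.5-turbo", temperature=0.8, openai_api_key="not-a-real-key")

assert settings_before == settings_after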
6 changes: 3 additions & 3 deletions argilla/docs/scripts/gen_popular_issues.py
@@ -116,21 +116,21 @@ def fetch_data_from_github(repository, auth_token):
f.write(" | Rank | Issue | Reactions | Comments |\n")
f.write(" |------|-------|:---------:|:--------:|\n")
for ix, row in engagement_df.iterrows():
f.write(f" | {ix+1} | [{row['Issue']}]({row['URL']}) | 👍 {row['Reactions']} | 💬 {row['Comments']} |\n")
f.write(f" | {ix + 1} | [{row['Issue']}]({row['URL']}) | 👍 {row['Reactions']} | 💬 {row['Comments']} |\n")

f.write('\n=== "Latest issues open by the community"\n\n')
f.write(" | Rank | Issue | Author |\n")
f.write(" |------|-------|:------:|\n")
for ix, row in community_issues_df.iterrows():
state = "🟢" if row["State"] == "open" else "🟣"
f.write(f" | {ix+1} | {state} [{row['Issue']}]({row['URL']}) | by **{row['Author']}** |\n")
f.write(f" | {ix + 1} | {state} [{row['Issue']}]({row['URL']}) | by **{row['Author']}** |\n")

f.write('\n=== "Planned issues for upcoming releases"\n\n')
f.write(" | Rank | Issue | Milestone |\n")
f.write(" |------|-------|:------:|\n")
for ix, row in planned_issues_df.iterrows():
state = "🟢" if row["State"] == "open" else "🟣"
f.write(f" | {ix+1} | {state} [{row['Issue']}]({row['URL']}) | **{row['Milestone']}** |\n")
f.write(f" | {ix + 1} | {state} [{row['Issue']}]({row['URL']}) | **{row['Milestone']}** |\n")

today = datetime.today().date()
f.write(f"\nLast update: {today}\n")
53 changes: 26 additions & 27 deletions argilla/docs/tutorials/image_classification.ipynb
@@ -93,13 +93,7 @@
"from PIL import Image\n",
"\n",
"from datasets import load_dataset, Dataset, load_metric\n",
"from transformers import (\n",
" AutoImageProcessor,\n",
" AutoModelForImageClassification,\n",
" pipeline,\n",
" Trainer,\n",
" TrainingArguments\n",
")\n",
"from transformers import AutoImageProcessor, AutoModelForImageClassification, pipeline, Trainer, TrainingArguments\n",
"\n",
"import argilla as rg"
]
@@ -182,7 +176,7 @@
" title=\"What digit do you see on the image?\",\n",
" labels=labels,\n",
" )\n",
" ]\n",
" ],\n",
")"
]
},
@@ -246,7 +240,7 @@
"n_rows = 100\n",
"\n",
"hf_dataset = load_dataset(\"ylecun/mnist\", streaming=True)\n",
"dataset_rows = [row for _,row in zip(range(n_rows), hf_dataset[\"train\"])]\n",
"dataset_rows = [row for _, row in zip(range(n_rows), hf_dataset[\"train\"])]\n",
"hf_dataset = Dataset.from_list(dataset_rows)\n",
"\n",
"hf_dataset"
@@ -525,7 +519,8 @@
],
"source": [
"def greyscale_to_rgb(img) -> Image:\n",
" return Image.merge('RGB', (img, img, img))\n",
" return Image.merge(\"RGB\", (img, img, img))\n",
"\n",
"\n",
"submitted_image_rgb = [\n",
" {\n",
@@ -556,7 +551,7 @@
"\n",
"submitted_image_rgb_processed = [\n",
" {\n",
" \"pixel_values\": processor(sample[\"image\"], return_tensors='pt')[\"pixel_values\"],\n",
" \"pixel_values\": processor(sample[\"image\"], return_tensors=\"pt\")[\"pixel_values\"],\n",
" \"label\": sample[\"label\"],\n",
" }\n",
" for sample in submitted_image_rgb\n",
@@ -624,8 +619,8 @@
"source": [
"def collate_fn(batch):\n",
" return {\n",
" 'pixel_values': torch.stack([torch.tensor(x['pixel_values'][0]) for x in batch]),\n",
" 'labels': torch.tensor([int(x['label']) for x in batch])\n",
" \"pixel_values\": torch.stack([torch.tensor(x[\"pixel_values\"][0]) for x in batch]),\n",
" \"labels\": torch.tensor([int(x[\"label\"]) for x in batch]),\n",
" }"
]
},
@@ -643,6 +638,8 @@
"outputs": [],
"source": [
"metric = load_metric(\"accuracy\", trust_remote_code=True)\n",
"\n",
"\n",
"def compute_metrics(p):\n",
" return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)"
]
@@ -664,7 +661,7 @@
" checkpoint,\n",
" num_labels=len(labels),\n",
" id2label={int(i): int(c) for i, c in enumerate(labels)},\n",
" label2id={int(c): int(i) for i, c in enumerate(labels)}\n",
" label2id={int(c): int(i) for i, c in enumerate(labels)},\n",
")\n",
"model.config"
]
@@ -698,19 +695,19 @@
],
"source": [
"training_args = TrainingArguments(\n",
" output_dir=\"./image-classifier\",\n",
" per_device_train_batch_size=16,\n",
" eval_strategy=\"steps\",\n",
" num_train_epochs=1,\n",
" fp16=False, # True if you have a GPU with mixed precision support\n",
" save_steps=100,\n",
" eval_steps=100,\n",
" logging_steps=10,\n",
" learning_rate=2e-4,\n",
" save_total_limit=2,\n",
" remove_unused_columns=True,\n",
" push_to_hub=False,\n",
" load_best_model_at_end=True,\n",
" output_dir=\"./image-classifier\",\n",
" per_device_train_batch_size=16,\n",
" eval_strategy=\"steps\",\n",
" num_train_epochs=1,\n",
" fp16=False, # True if you have a GPU with mixed precision support\n",
" save_steps=100,\n",
" eval_steps=100,\n",
" logging_steps=10,\n",
" learning_rate=2e-4,\n",
" save_total_limit=2,\n",
" remove_unused_columns=True,\n",
" push_to_hub=False,\n",
" load_best_model_at_end=True,\n",
")\n",
"\n",
"trainer = Trainer(\n",
@@ -745,12 +742,14 @@
"source": [
"pipe = pipeline(\"image-classification\", model=model, image_processor=processor)\n",
"\n",
"\n",
"def run_inference(batch):\n",
" predictions = pipe(batch[\"image\"])\n",
" batch[\"image_label\"] = [prediction[0][\"label\"] for prediction in predictions]\n",
" batch[\"score\"] = [prediction[0][\"score\"] for prediction in predictions]\n",
" return batch\n",
"\n",
"\n",
"hf_dataset = hf_dataset.map(run_inference, batched=True)"
]
},
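Many of the remaining notebook hunks only add blank lines: the formatter keeps two blank lines between top-level statements and top-level function definitions inside a cell, which is why compute_metrics and run_inference gain empty lines around them. A short sketch of the rule with a made-up function:

values = [0.1, 0.9]


def largest(xs):
    # The two blank lines above are what the formatter inserts between a statement and a def.
    return max(xs)


print(largest(values))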
(the remaining changed files in this commit are not shown here)
