[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge branch 'main' into refactor-datagrid-filter
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Nov 28, 2023
2 parents 18ed414 + 1759389 commit fc77515
Show file tree
Hide file tree
Showing 34 changed files with 1,704 additions and 851 deletions.
3 changes: 2 additions & 1 deletion .eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ module.exports = {
'@typescript-eslint',
],
rules: {
"@typescript-eslint/ban-ts-comment": "off"
"@typescript-eslint/ban-ts-comment": "off",
"eqeqeq": ["error", "smart"],
},
extends: [
'eslint:recommended',
Expand Down
8 changes: 7 additions & 1 deletion .github/workflows/e2e-dashboard-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,10 @@ jobs:
run: playwright install

- name: Run e2e tests
run: pytest e2e_tests/test_dashboard
run: |
if [ "${{ matrix.optuna-version }}" = "optuna==2.10.0" ]; then
ignore_option="--ignore e2e_tests/test_dashboard/test_usecases/test_preferential_optimization.py"
else
ignore_option=""
fi
pytest e2e_tests/test_dashboard $ignore_option
4 changes: 4 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ $ pytest python_tests/

```
$ pip install -r requirements.txt
$ playwright install
$ pytest e2e_tests
```

Expand All @@ -92,6 +93,9 @@ If you want to create a screenshot for each test, please run a following command
$ pytest e2e_tests --screenshot on --output tmp
```

If you want to generate locators for a web page, use Playwright's codegen tool. See [this page](https://playwright.dev/python/docs/codegen-intro) for more details.


For more detailed options, see [this page](https://playwright.dev/python/docs/test-runners).

#### Linters (flake8, black and mypy)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg?style=flat-square)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/optuna-dashboard)](https://pypistats.org/packages/optuna-dashboard)
[![Read the Docs](https://readthedocs.org/projects/optuna-dashboard/badge/?version=latest)](https://optuna-dashboard.readthedocs.io/en/latest/?badge=latest)
[![Codecov](https://codecov.io/gh/optuna/optuna-dashboard/branch/main/graph/badge.svg)](https://codecov.io/gh/optuna/optuna-dashboard)


Real-time dashboard for [Optuna](https://github.com/optuna/optuna).
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import re

import optuna
from optuna.trial import TrialState
from optuna_dashboard import ChoiceWidget
from optuna_dashboard import register_objective_form_widgets
from playwright.sync_api import expect
from playwright.sync_api import Page
import pytest

from ...test_server import make_test_server


def make_test_storage() -> optuna.storages.InMemoryStorage:
    """Build an in-memory storage holding one study with four running trials.

    The study is named ``preferential_optimization``, sampled with a seeded
    ``RandomSampler``, and has a single ``ChoiceWidget`` objective form that
    maps the answers "Good" / "So-so" / "Bad" to the values -1 / 0 / 1.
    """
    in_memory = optuna.storages.InMemoryStorage()
    study = optuna.create_study(
        study_name="preferential_optimization",
        storage=in_memory,
        sampler=optuna.samplers.RandomSampler(seed=0),
    )

    feedback_widget = ChoiceWidget(choices=["Good", "So-so", "Bad"], values=[-1, 0, 1])
    register_objective_form_widgets(study, widgets=[feedback_widget])

    n_batch = 4
    # Ask for new trials until the batch of RUNNING trials is full; they stay
    # running so the e2e test can submit feedback for them through the UI.
    while len(study.get_trials(deepcopy=False, states=(TrialState.RUNNING,))) < n_batch:
        study.ask()

    return in_memory


@pytest.fixture
def storage() -> optuna.storages.InMemoryStorage:
    """Provide a fresh storage pre-populated by ``make_test_storage``."""
    return make_test_storage()


@pytest.fixture
def server_url(request: pytest.FixtureRequest, storage: optuna.storages.InMemoryStorage) -> str:
    """Serve ``storage`` via ``make_test_server`` and return the URL it yields."""
    url = make_test_server(request, storage)
    return url


def _expect_trial_heading(page: Page, number: int) -> None:
    """Assert that the visible "Trial" heading refers to trial ``number``."""
    heading = page.get_by_role("heading").filter(has_text=re.compile("Trial"))
    expect(heading).to_contain_text(f"Trial {number} (trial_id={number})")


def test_preferential_optimization(
    page: Page,
    storage: optuna.storages.InMemoryStorage,
    server_url: str,
) -> None:
    """Walk through giving feedback on three running trials and failing the fourth."""
    study_id = optuna.get_all_study_summaries(storage)[0]._study_id
    page.goto(f"{server_url}/studies/{study_id}/trials")

    # The trial list page opens on the first trial.
    _expect_trial_heading(page, 0)

    # Before any feedback, every trial is still running.
    page.get_by_label("Filter").click()
    expect(
        page.get_by_text("Complete (0)Pruned (0)Fail (0)Running (4)Waiting (0)")
    ).to_be_visible()
    page.locator(".MuiBackdrop-root").click()

    # Each row: (trial number, answer to check, nth "Running" label expected
    # before submitting, nth "Complete" label expected after submitting).
    # The "Running" index shrinks and the "Complete" index grows as trials finish.
    evaluations = [
        (0, "Bad", 4, 1),
        (1, "So-so", 3, 2),
        (2, "Good", 2, 3),
    ]
    for number, answer, running_index, complete_index in evaluations:
        if number > 0:
            # Navigate to the next trial's detail page.
            page.get_by_role("button", name=f"Trial {number} Running").click()
        _expect_trial_heading(page, number)
        expect(page.get_by_text("Running", exact=True).nth(running_index)).to_be_visible()
        page.get_by_label(answer).check()
        page.get_by_role("button", name="Submit").click()
        # The submitted trial is now complete and shown as the best trial.
        expect(page.get_by_text("Complete").nth(complete_index)).to_be_visible()
        expect(page.get_by_text("Best Trial").nth(1)).to_be_visible()

    # The last running trial is failed instead of evaluated.
    page.get_by_role("button", name="Trial 3 Running").click()
    _expect_trial_heading(page, 3)
    expect(page.get_by_text("Running", exact=True).nth(1)).to_be_visible()
    page.get_by_role("button", name="Fail Trial").click()
    expect(page.get_by_text("Fail").nth(1)).to_be_visible()
4 changes: 2 additions & 2 deletions optuna_dashboard/_cached_extra_study_property.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import numbers
import threading
from typing import List
from typing import Optional
Expand Down Expand Up @@ -85,9 +86,8 @@ def update(self, trials: list[FrozenTrial]) -> None:
self._cursor = next_cursor

def _update_user_attrs(self, trial: FrozenTrial) -> None:
# TODO(c-bata): Support numpy-specific number types.
current_user_attrs = {
k: not isinstance(v, bool) and isinstance(v, (int, float))
k: not isinstance(v, bool) and isinstance(v, numbers.Real)
for k, v in trial.user_attrs.items()
}
for attr_name, current_is_sortable in current_user_attrs.items():
Expand Down
3 changes: 3 additions & 0 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from datetime import datetime
import json
import numbers
from typing import Any
from typing import TYPE_CHECKING
from typing import Union
Expand Down Expand Up @@ -104,6 +105,8 @@ def serialize_attrs(attrs: dict[str, Any]) -> list[Attribute]:
value = "<binary object>"
elif isinstance(v, str):
value = v
elif isinstance(v, numbers.Real):
value = str(v)
else:
value = json.dumps(v)
value = value[:MAX_ATTR_LENGTH] if len(value) > MAX_ATTR_LENGTH else value
Expand Down
60 changes: 54 additions & 6 deletions optuna_dashboard/artifact/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def proxy_trial_artifact(

@app.post("/api/artifacts/<study_id:int>/<trial_id:int>")
@json_api_view
def upload_artifact_api(study_id: int, trial_id: int) -> dict[str, Any]:
def upload_trial_artifact_api(study_id: int, trial_id: int) -> dict[str, Any]:
trial = storage.get_trial(trial_id)
if trial is None:
response.status = 400
Expand Down Expand Up @@ -144,17 +144,50 @@ def upload_artifact_api(study_id: int, trial_id: int) -> dict[str, Any]:
"artifacts": list_trial_artifacts(storage.get_study_system_attrs(study_id), trial),
}

@app.post("/api/artifacts/<study_id:int>")
@json_api_view
def upload_study_artifact_api(study_id: int) -> dict[str, Any]:
    """Store an uploaded file as a study-level artifact.

    The request body is expected to be a JSON object with a ``file`` key,
    whose value is decoded by ``parse_data_uri``, and an optional
    ``filename`` key. The decoded bytes are written to ``artifact_store``
    under a fresh UUID4 id. The artifact's metadata is saved as a JSON string
    in the study's system attrs under ``ARTIFACTS_ATTR_PREFIX + artifact_id``.

    Returns:
        On success (status 201), the new ``artifact_id`` and the study's
        artifact list. On failure (status 400), a dict with a ``reason``.
    """
    if artifact_store is None:
        response.status = 400  # Bad Request
        return {"reason": "Cannot access to the artifacts."}

    # bottle's ``request.json`` is None unless the request declares a JSON
    # content type. Treat that (or a non-object JSON body) as a missing 'file'
    # key, so the client gets a 400 instead of an AttributeError.
    payload = request.json
    if not isinstance(payload, dict):
        payload = {}
    file = payload.get("file")
    if file is None:
        response.status = 400
        return {"reason": "Please specify the 'file' key."}

    _, data = parse_data_uri(file)
    filename = payload.get("filename", "")
    artifact_id = str(uuid.uuid4())
    artifact_store.write(artifact_id, io.BytesIO(data))

    # The mimetype is guessed from the filename only; fall back to the default.
    mimetype, encoding = mimetypes.guess_type(filename)
    artifact = {
        "artifact_id": artifact_id,
        "filename": filename,
        "mimetype": mimetype or DEFAULT_MIME_TYPE,
        "encoding": encoding,
    }
    attr_key = ARTIFACTS_ATTR_PREFIX + artifact_id
    storage.set_study_system_attr(study_id, attr_key, json.dumps(artifact))

    response.status = 201

    return {
        "artifact_id": artifact_id,
        "artifacts": list_study_artifacts(storage.get_study_system_attrs(study_id)),
    }

@app.delete("/api/artifacts/<study_id:int>/<trial_id:int>/<artifact_id:re:[0-9a-fA-F-]+>")
@json_api_view
def delete_artifact(study_id: int, trial_id: int, artifact_id: str) -> dict[str, Any]:
def delete_trial_artifact(study_id: int, trial_id: int, artifact_id: str) -> dict[str, Any]:
if artifact_store is None:
response.status = 400 # Bad Request
return {"reason": "Cannot access to the artifacts."}
artifact_store.remove(artifact_id)

# The artifact's metadata is stored in one of the following two locations:
storage.set_study_system_attr(
study_id, _dashboard_trial_artifact_prefix(trial_id) + artifact_id, json.dumps(None)
study_id, _dashboard_artifact_prefix(trial_id) + artifact_id, json.dumps(None)
)
storage.set_trial_system_attr(
trial_id, ARTIFACTS_ATTR_PREFIX + artifact_id, json.dumps(None)
Expand All @@ -163,6 +196,21 @@ def delete_artifact(study_id: int, trial_id: int, artifact_id: str) -> dict[str,
response.status = 204
return {}

@app.delete("/api/artifacts/<study_id:int>/<artifact_id:re:[0-9a-fA-F-]+>")
@json_api_view
def delete_study_artifact(study_id: int, artifact_id: str) -> dict[str, Any]:
    """Delete a study-level artifact and null out its metadata entry.

    Responds with status 400 when no artifact store is configured; otherwise
    removes the artifact, overwrites its study system attr with JSON ``null``,
    and responds with status 204 and an empty body.
    """
    if artifact_store is not None:
        artifact_store.remove(artifact_id)

        # The metadata key is overwritten with a JSON null rather than deleted.
        metadata_key = ARTIFACTS_ATTR_PREFIX + artifact_id
        storage.set_study_system_attr(study_id, metadata_key, json.dumps(None))

        response.status = 204
        return {}

    response.status = 400  # Bad Request
    return {"reason": "Cannot access to the artifacts."}


def upload_artifact(
backend: ArtifactBackend,
Expand Down Expand Up @@ -220,7 +268,7 @@ def objective(trial: optuna.Trial) -> float:
return artifact_id


def _dashboard_trial_artifact_prefix(trial_id: int) -> str:
def _dashboard_artifact_prefix(trial_id: int) -> str:
return DASHBOARD_ARTIFACTS_ATTR_PREFIX + f"{trial_id}:"


Expand All @@ -240,7 +288,7 @@ def get_trial_artifact_meta(
) -> Optional[ArtifactMeta]:
# Search study_system_attrs due to backward compatibility.
study_system_attrs = storage.get_study_system_attrs(study_id)
attr_key = _dashboard_trial_artifact_prefix(trial_id=trial_id) + artifact_id
attr_key = _dashboard_artifact_prefix(trial_id=trial_id) + artifact_id
artifact_meta = study_system_attrs.get(attr_key)
if artifact_meta is not None:
return json.loads(artifact_meta)
Expand Down Expand Up @@ -284,7 +332,7 @@ def list_trial_artifacts(
dashboard_artifact_metas = [
json.loads(value)
for key, value in study_system_attrs.items()
if key.startswith(_dashboard_trial_artifact_prefix(trial._trial_id))
if key.startswith(_dashboard_artifact_prefix(trial._trial_id))
]

# Collect ArtifactMeta from trial_system_attrs. Note that artifacts uploaded via
Expand Down
Loading

0 comments on commit fc77515

Please sign in to comment.