Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions roboflow/adapters/rfapi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import urllib
from typing import Optional
from typing import List, Optional

import requests
from requests.exceptions import RequestException
Expand Down Expand Up @@ -56,7 +56,7 @@ def upload_image(
hosted_image: bool = False,
split: str = "train",
batch_name: str = DEFAULT_BATCH_NAME,
tag_names: list = [],
tag_names: Optional[List[str]] = None,
sequence_number: Optional[int] = None,
sequence_size: Optional[int] = None,
**kwargs,
Expand All @@ -71,6 +71,8 @@ def upload_image(
"""

coalesced_batch_name = batch_name or DEFAULT_BATCH_NAME
if tag_names is None:
tag_names = []

# If image is not a hosted image
if not hosted_image:
Expand Down
41 changes: 29 additions & 12 deletions roboflow/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,7 @@ def generate_version(self, settings):

def train(
self,
new_version_settings={
"preprocessing": {
"auto-orient": True,
"resize": {"width": 640, "height": 640, "format": "Stretch to"},
},
"augmentation": {},
},
new_version_settings: Optional[Dict] = None,
speed=None,
checkpoint=None,
plot_in_notebook=False,
Expand Down Expand Up @@ -294,6 +288,15 @@ def train(
>>> version.train()
""" # noqa: E501 // docs

if new_version_settings is None:
new_version_settings = {
"preprocessing": {
"auto-orient": True,
"resize": {"width": 640, "height": 640, "format": "Stretch to"},
},
"augmentation": {},
}

new_version = self.generate_version(settings=new_version_settings)
new_version = self.version(new_version)
new_model = new_version.train(speed=speed, checkpoint=checkpoint, plot_in_notebook=plot_in_notebook)
Expand Down Expand Up @@ -384,7 +387,7 @@ def upload(
split: str = "train",
num_retry_uploads: int = 0,
batch_name: Optional[str] = None,
tag_names: list = [],
tag_names: Optional[List[str]] = None,
is_prediction: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -413,6 +416,9 @@ def upload(
>>> project.upload(image_path="YOUR_IMAGE.jpg")
""" # noqa: E501 // docs

if tag_names is None:
tag_names = []

is_hosted = image_path.startswith("http://") or image_path.startswith("https://")

is_file = os.path.isfile(image_path) or is_hosted
Expand Down Expand Up @@ -476,13 +482,16 @@ def upload_image(
split="train",
num_retry_uploads=0,
batch_name=None,
tag_names=[],
tag_names: Optional[List[str]] = None,
sequence_number=None,
sequence_size=None,
**kwargs,
):
project_url = self.id.rsplit("/")[1]

if tag_names is None:
tag_names = []

t0 = time.time()
upload_retry_attempts = 0
retry = Retry(num_retry_uploads, ImageUploadError)
Expand Down Expand Up @@ -557,13 +566,15 @@ def single_upload(
split="train",
num_retry_uploads=0,
batch_name=None,
tag_names=[],
tag_names: Optional[List[str]] = None,
is_prediction: bool = False,
annotation_overwrite=False,
sequence_number=None,
sequence_size=None,
**kwargs,
):
if tag_names is None:
tag_names = []
if image_path and image_id:
raise Exception("You can't pass both image_id and image_path")
if not (image_path or image_id):
Expand Down Expand Up @@ -641,7 +652,7 @@ def search(
in_dataset: Optional[str] = None,
batch: bool = False,
batch_id: Optional[str] = None,
fields: list = ["id", "created", "name", "labels"],
fields: Optional[List[str]] = None,
):
"""
Search for images in a project.
Expand Down Expand Up @@ -670,6 +681,9 @@ def search(

>>> results = project.search(query="cat", limit=10)
""" # noqa: E501 // docs
if fields is None:
fields = ["id", "created", "name", "labels"]

payload: Dict[str, Union[str, int, List[str]]] = {}

if like_image is not None:
Expand Down Expand Up @@ -719,7 +733,7 @@ def search_all(
in_dataset: Optional[str] = None,
batch: bool = False,
batch_id: Optional[str] = None,
fields: list = ["id", "created"],
fields: Optional[List[str]] = None,
):
"""
Create a paginated list of search results for use in searching the images in a project.
Expand Down Expand Up @@ -752,6 +766,9 @@ def search_all(

>>> print(result)
""" # noqa: E501 // docs
if fields is None:
fields = ["id", "created"]

while True:
data = self.search(
like_image=like_image,
Expand Down
11 changes: 8 additions & 3 deletions roboflow/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import os
import sys
from typing import Any, List
from typing import Any, Dict, List, Optional

import requests
from PIL import Image
Expand Down Expand Up @@ -433,9 +433,9 @@ def active_learning(
self,
raw_data_location: str = "",
raw_data_extension: str = "",
inference_endpoint: list = [],
inference_endpoint: Optional[List[str]] = None,
upload_destination: str = "",
conditionals: dict = {},
conditionals: Optional[Dict] = None,
use_localhost: bool = False,
local_server="http://localhost:9001/",
) -> Any:
Expand All @@ -449,6 +449,11 @@ def active_learning(
use_localhost: (bool) = determines if local http format used or remote endpoint
local_server: (str) = local http address for inference server, use_localhost must be True for this to be used
""" # noqa: E501 // docs
if inference_endpoint is None:
inference_endpoint = []
if conditionals is None:
conditionals = {}

import numpy as np

prediction_results = []
Expand Down
7 changes: 5 additions & 2 deletions roboflow/models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import time
import urllib
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from urllib.parse import urljoin

import requests
Expand Down Expand Up @@ -137,7 +137,7 @@ def predict_video(
self,
video_path: str,
fps: int = 5,
additional_models: list = [],
additional_models: Optional[List[str]] = None,
prediction_type: str = "batch-video",
) -> Tuple[str, str, Optional[str]]:
"""
Expand Down Expand Up @@ -170,6 +170,9 @@ def predict_video(
if fps > 120:
raise Exception("FPS must be less than or equal to 120.")

if additional_models is None:
additional_models = []

for model in additional_models:
if model not in SUPPORTED_ADDITIONAL_MODELS:
raise Exception(f"Model {model} is not supported for video inference.")
Expand Down