init
tools/pretraining_data_builder/.gitignore (vendored, new file, 9 lines)
@@ -0,0 +1,9 @@
.venv/
.env
**/*.pyc
**/.pytest_cache
search_data.json
.idea
*.zip
*.jp2
.ruff_cache
tools/pretraining_data_builder/README.md (new file, 36 lines)
@@ -0,0 +1,36 @@
# Pretraining Data Builder

This code builds the pretraining data for the self-supervised learning of SkySense++.

## Install

Prepare the environment:

```
conda create -n data_builder python=3.12
conda activate data_builder
pip install -r requirements.txt
```

Download the pretraining data list in LMDB format from [Zenodo](https://zenodo.org/records/14994430).

## Download Data

```
python -m rsi_download --username <username> --password <password> --api_key <api_key> <X> <Y> <Z> <date_min> <date_max>
```

Notes:

1. `username` and `password` can be created in the [Copernicus Data Space Ecosystem](https://data.copernicus.eu/cdsapp/#!/home); `api_key` can be created in the [Maxar](https://ard.maxar.com/docs/about/) portal.
2. `X`, `Y`, `Z` are tile coordinates in the Web Mercator (XYZ) tiling scheme.
3. `date_min` and `date_max` are in the format `YYYYMM` (e.g. `202206`), as expected by `download_core` in `rsi_download/download_async.py`.
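For reference, a tile index can be converted to its geographic bounds with the `mercantile` package already listed in `requirements.txt` (a minimal sketch; the tile values are arbitrary):

```
import mercantile

x, y, z = 857, 418, 10                 # example tile indices
bounds = mercantile.bounds(x, y, z)    # west/south/east/north in WGS84 degrees
print(bounds)
```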
## Process Data

```
python -m rsi_process --platform <platform> --fn_img path/to/image.zip --save_dir output_<platform>/
```

Notes:

1. `platform` can be `s1`, `s2`, or `wv`.
2. `fn_img` is the path to the downloaded zip file.
3. `save_dir` is the directory in which to save the processed data.

## Automatic Script

```
sh run_data_builder.sh
```

This script first reads the pretraining list, downloads the data according to the list, and processes it automatically.
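The shell script itself is not shown here; judging from the entry points above, it presumably boils down to invoking the pipeline driver on the downloaded list, e.g.:

```
python rsi_pipeline/data_builder.py path/to/pretraining_list.lmdb
```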
tools/pretraining_data_builder/requirements.txt (new file, 17 lines)
@@ -0,0 +1,17 @@
httpx>=0.27.2
python-dotenv>=1.0.0
orjson>=3.9.10
rich>=13.7.0
click>=8.1.7
msgspec>=0.18.4
asyncclick>=8.1.3.4
numpy
gdal
pyproj
mercantile
Pillow
shapely
imageio
geopandas
pyresample
lmdb
tools/pretraining_data_builder/rsi_download/__main__.py (new file, 76 lines)
@@ -0,0 +1,76 @@
|
||||
import click
|
||||
from rsi_download.download_async import download_core
|
||||
import asyncio
|
||||
|
||||
@click.command()
|
||||
@click.argument("x", type=click.STRING)
|
||||
@click.argument("y", type=click.STRING)
|
||||
@click.argument("z", type=click.STRING)
|
||||
@click.argument("date_min", type=click.STRING)
|
||||
@click.argument("date_max", type=click.STRING)
|
||||
@click.option(
|
||||
"--username",
|
||||
"-u",
|
||||
type=click.STRING,
|
||||
help="Username for Copernicus Data Space Ecosystem",
|
||||
)
|
||||
@click.option(
|
||||
"--password", "-p", prompt=True, hide_input=True, confirmation_prompt=False
|
||||
)
|
||||
@click.option(
|
||||
"--api_key", "-k", prompt=True, hide_input=True, confirmation_prompt=False
|
||||
)
|
||||
@click.option(
|
||||
"--max",
|
||||
"-m",
|
||||
"max_",
|
||||
default=100,
|
||||
type=click.INT,
|
||||
show_default=True,
|
||||
help="maximum number of results returned",
|
||||
)
|
||||
@click.option(
|
||||
"--cloud-coverage",
|
||||
"-c",
|
||||
"cloud_coverage",
|
||||
default=10.00,
|
||||
type=click.FLOAT,
|
||||
show_default=True,
|
||||
help="Get only results with a cloud coverage percentage less then the argument given.",
|
||||
)
|
||||
|
||||
@click.option(
|
||||
"--platform-name",
|
||||
"-n",
|
||||
"platform_name",
|
||||
default="S2",
|
||||
type=click.Choice(["S2", "S1", "WV3"]),
|
||||
show_default=True,
|
||||
help="Get only results with a platform name.",
|
||||
)
|
||||
|
||||
@click.option(
|
||||
"--debug",
|
||||
default=False,
|
||||
is_flag=True,
|
||||
type=click.BOOL,
|
||||
show_default=True,
|
||||
help="Debug the http requests and extra debug logging",
|
||||
)
|
||||
@click.option(
|
||||
"--tci",
|
||||
default=False,
|
||||
is_flag=True,
|
||||
type=click.BOOL,
|
||||
show_default=True,
|
||||
help="Download only True Color Image (TCI)",
|
||||
)
|
||||
|
||||
def main(x, y, z, date_min, date_max, username, password, api_key, max_, cloud_coverage, debug, tci, platform_name):
|
||||
return asyncio.run(download_core(x, y, z, date_min, date_max, username, password, api_key, max_, cloud_coverage, debug, tci, platform_name))
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n程序已终止")
|
||||
tools/pretraining_data_builder/rsi_download/auth.py (new file, 36 lines)
@@ -0,0 +1,36 @@
import httpx
import msgspec


class CDSETokens(msgspec.Struct):
    """Copernicus Data Space Ecosystem tokens"""

    access_token: str
    refresh_token: str
    expires_in: int
    refresh_expires_in: int
    token_type: str
    not_before_policy: int = msgspec.field(name="not-before-policy")
    session_state: str
    scope: str


def get_access_token(username: str, password: str) -> CDSETokens:
    data = {
        "client_id": "cdse-public",
        "username": username,
        "password": password,
        "grant_type": "password",
    }
    try:
        with httpx.Client() as client:
            r = client.post(
                "https://identity.dataspace.copernicus.eu/auth/realms/CDSE/protocol/openid-connect/token",
                data=data,
            )
            r.raise_for_status()
    except httpx.HTTPStatusError as e:
        # The request reached the server but came back with an error status.
        raise Exception(
            f"Access token creation failed: {e}. Response from the server was: {e.response.json()}"
        )
    except httpx.HTTPError as e:
        # Connection-level failure: there is no response to report.
        raise Exception(f"Access token creation failed: {e}")
    return msgspec.json.decode(r.content, type=CDSETokens)
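For context, a minimal sketch of how these tokens are consumed by the download code later in this commit (the credential values are placeholders):

```
from rsi_download.auth import get_access_token

tokens = get_access_token("user@example.com", "secret")       # placeholder credentials
headers = {"Authorization": f"Bearer {tokens.access_token}"}  # same header built in download/product.py
# The OData/zipper download requests below attach exactly this header.
```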
tools/pretraining_data_builder/rsi_download/cli.py (new file, 133 lines)
@@ -0,0 +1,133 @@
|
||||
from typing import Tuple, List
|
||||
|
||||
from rich.table import Table
|
||||
from rich.console import Console
|
||||
import re
|
||||
import msgspec
|
||||
from datetime import datetime
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
DownloadColumn,
|
||||
Progress,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
TransferSpeedColumn,
|
||||
)
|
||||
from rsi_download.exceptions import InvalidWktPointArgument, InvalidDateRangeArgument
|
||||
from rsi_download.download.search import SearchContent, SearchResult
|
||||
|
||||
|
||||
class Preview(msgspec.Struct):
|
||||
id: str
|
||||
productid: str
|
||||
url: str
|
||||
origin_date: str
|
||||
name: str
|
||||
|
||||
|
||||
progress = Progress(
|
||||
TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
|
||||
BarColumn(bar_width=None),
|
||||
"[progress.percentage]{task.percentage:>3.1f}%",
|
||||
"•",
|
||||
DownloadColumn(),
|
||||
"•",
|
||||
TransferSpeedColumn(),
|
||||
"•",
|
||||
TimeRemainingColumn(),
|
||||
)
|
||||
|
||||
|
||||
# "2022-05-03T00:00:00.000Z"
|
||||
ESA_DATE_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
|
||||
|
||||
def convert_to_timestamp(datestring="", dateformat="%d-%m-%Y %H:%M:%S") -> str:
|
||||
if len(datestring) > 10:
|
||||
source = datetime.strptime(datestring, dateformat)
|
||||
else:
|
||||
source = datetime.strptime(datestring, "%d-%m-%Y")
|
||||
return source.strftime(ESA_DATE_FORMAT)
|
||||
|
||||
|
||||
def daterange_to_timestamp(daterange: str) -> Tuple[str, str]:
|
||||
if "," not in daterange:
|
||||
raise InvalidDateRangeArgument(
|
||||
f'Give a valid daterange string. for example: "11-08-2023 00:00:00,11-09-2023 00:00:00" \n Daterange received: {daterange}'
|
||||
)
|
||||
gt, lt = daterange.split(",")
|
||||
try:
|
||||
time_gt = convert_to_timestamp(datestring=gt)
|
||||
except ValueError:
|
||||
raise InvalidDateRangeArgument(
|
||||
f"Invalid dateformat encountered for time_gt: {gt}. Dateformat expected: %d-%m-%Y or %d-%m-%Y %H:%M:%S"
|
||||
)
|
||||
try:
|
||||
time_lt = convert_to_timestamp(datestring=lt)
|
||||
except ValueError:
|
||||
raise InvalidDateRangeArgument(
|
||||
f"Invalid dateformat encountered for time_lt: {lt}. Dateformat expected: %d-%m-%Y or %d-%m-%Y %H:%M:%S"
|
||||
)
|
||||
return time_gt, time_lt
|
||||
|
||||
|
||||
def wkt_to_point(wktstring: str) -> Tuple[float, ...]:
|
||||
nums = re.findall(r"[-+]?\d*\.\d+|\d+", wktstring)
|
||||
if len(nums) != 2:
|
||||
raise InvalidWktPointArgument(
|
||||
f"Give a valid WKT string. for example: POINT(-9.1372 38.7000). WKT received: {wktstring}"
|
||||
)
|
||||
return tuple(float(n) for n in nums)
|
||||
|
||||
|
||||
def show_preview_urls(search_json: SearchContent, platform_name: str) -> List[Preview]:
|
||||
"""
|
||||
Show a list of preview urls for downloading in the terminal
|
||||
|
||||
:param search_json: SearchContent object
|
||||
"""
|
||||
# print(search_json.value)
|
||||
preview_urls = [
|
||||
Preview(
|
||||
id=str(i),
|
||||
productid=v.id,
|
||||
url=v.assets[0].download_link,
|
||||
origin_date=v.content_date.start,
|
||||
name=v.name,
|
||||
)
|
||||
for i, v in enumerate(search_json.value)
|
||||
]
|
||||
table = Table(title="RSI Preview URLs")
|
||||
table.add_column("ID", justify="left", style="magenta")
|
||||
table.add_column("Acquisition Time", justify="left", style="blue")
|
||||
table.add_column("Name", justify="left", style="magenta")
|
||||
|
||||
for entry in preview_urls:
|
||||
table.add_row(
|
||||
entry.id,
|
||||
f'[link={entry.url.replace("(", "%28").replace(")", "%29")}]{entry.origin_date}[/link]',
|
||||
entry.name,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
console.print(table)
|
||||
return preview_urls
|
||||
|
||||
|
||||
def get_selected_products(
|
||||
search_json: SearchContent, preview_urls: List[Preview], product_ids: List[int]
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Return the selected items from the search_json by the preview url id.
|
||||
|
||||
:param search_json: SearchContent
|
||||
:param preview_urls: List[Preview]
|
||||
:param product_ids: list of preview ids to download
|
||||
:return: List[SearchResult]
|
||||
"""
|
||||
download_product_ids = [
|
||||
item.productid
|
||||
for item in preview_urls
|
||||
if item.id in [str(n) for n in product_ids]
|
||||
]
|
||||
return [x for x in search_json.value if x.id in download_product_ids]
|
||||
tools/pretraining_data_builder/rsi_download/download/product.py (new file, 97 lines; path inferred from the imports in download_async.py)
@@ -0,0 +1,97 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
import signal
|
||||
import httpx
|
||||
from rich.progress import TaskID
from threading import Event  # Event is not exported by rich.progress
|
||||
from rsi_download.cli import progress
|
||||
from rsi_download.download.search import SearchResult
|
||||
from rsi_download.cli import Preview
|
||||
import os
|
||||
|
||||
done_event = Event()
|
||||
|
||||
|
||||
def handle_sigint(signum, frame):
|
||||
done_event.set()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
|
||||
|
||||
async def download_tci_products_data(
|
||||
task_id: TaskID, product: SearchResult, access_token: str, mm_band: str = "R10m"
|
||||
):
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
progress.start_task(task_id)
|
||||
async with httpx.AsyncClient() as client:
|
||||
client.headers.update(headers)
|
||||
# create the tci image url
|
||||
granule_url = f"https://zipper.dataspace.copernicus.eu/odata/v1/Products({product.id})/Nodes({product.name})/Nodes(GRANULE)/Nodes"
|
||||
granule_resp = await client.get(
|
||||
f"{granule_url}", follow_redirects=True, headers=headers
|
||||
)
|
||||
granule_folder = granule_resp.json()
|
||||
img_data_url = f"{granule_url}({granule_folder['result'][0]['Name']})/Nodes(IMG_DATA)/Nodes({mm_band})/Nodes"
|
||||
img_data_resp = await client.get(img_data_url, follow_redirects=True)
|
||||
img_data = img_data_resp.json()
|
||||
tci_name = [img["Name"] for img in img_data["result"] if "TCI" in img["Name"]][
|
||||
0
|
||||
]
|
||||
tci_url = f"{img_data_url}({tci_name})/$value"
|
||||
async with client.stream(
|
||||
method="GET",
|
||||
url=tci_url,
|
||||
headers=headers,
|
||||
) as response:
|
||||
progress.update(task_id, total=int(response.headers["Content-length"]))
|
||||
with open(f"{tci_name}", "wb") as file:
|
||||
progress.start_task(task_id)
|
||||
async for chunk in response.aiter_bytes():
|
||||
if chunk:
|
||||
file.write(chunk)
|
||||
progress.update(task_id, advance=len(chunk))
|
||||
if done_event.is_set():
|
||||
return
|
||||
progress.console.log(f"Downloaded {tci_name}")
|
||||
|
||||
|
||||
async def download_data(task_id: TaskID, product: SearchResult, preview: Preview, access_token: str):
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
async with httpx.AsyncClient() as client:
|
||||
client.headers.update(headers)
|
||||
async with client.stream(
|
||||
"GET",
|
||||
url=f"https://zipper.dataspace.copernicus.eu/odata/v1/Products({product.id})/$value",
|
||||
headers=headers,
|
||||
) as response:
|
||||
progress.update(task_id, total=int(response.headers["Content-length"]))
|
||||
with open(f"out_raw/{preview.name.replace('.SAFE', '.zip')}", "wb") as file:
|
||||
progress.start_task(task_id)
|
||||
async for chunk in response.aiter_bytes():
|
||||
if chunk:
|
||||
file.write(chunk)
|
||||
progress.update(task_id, advance=len(chunk))
|
||||
if done_event.is_set():
|
||||
return
|
||||
progress.console.log(f"Downloaded {preview.name.replace('.SAFE', '.zip')}")
|
||||
|
||||
async def download_products_data(
|
||||
products: List[SearchResult], previews: List[Preview], access_token: str, tci_only: bool = False
|
||||
):
|
||||
with progress:
|
||||
download_tasks = []
|
||||
for product, preview in zip(products, previews):
|
||||
task_id = progress.add_task(
|
||||
f"{preview.name.replace('.SAFE', '.zip')}",
|
||||
filename=f"{preview.name.replace('.SAFE', '.zip')}",
|
||||
start=False,
|
||||
)
|
||||
if tci_only:
|
||||
download_tasks.append(
|
||||
download_tci_products_data(task_id, product, access_token)
|
||||
)
|
||||
else:
|
||||
download_tasks.append(download_data(task_id, product, preview, access_token))
|
||||
# os.rename(f"product-{product.id}.zip", f"{preview.name.replace('.SAFE', '.zip')}")
|
||||
await asyncio.gather(*download_tasks)
|
||||
|
||||
tools/pretraining_data_builder/rsi_download/download/search.py (new file, 78 lines; path inferred from the imports in cli.py and download_async.py)
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
import msgspec
|
||||
import httpx
|
||||
|
||||
from rsi_download.exceptions import SearchException
|
||||
from rsi_download.geo.geo_types import GeoJsonPolygon
|
||||
|
||||
ESA_SEARCH_URL = r"https://catalogue.dataspace.copernicus.eu/odata/v1/Products"
|
||||
|
||||
|
||||
class ContentData(msgspec.Struct, rename="pascal"):
|
||||
"""Odata search result start and end date"""
|
||||
|
||||
start: str
|
||||
end: str
|
||||
|
||||
|
||||
class Asset(msgspec.Struct, rename="pascal"):
|
||||
"""Odata search Asset"""
|
||||
|
||||
type_: str = msgspec.field(name="Type")
|
||||
id: str
|
||||
download_link: str
|
||||
s3_path: str
|
||||
|
||||
|
||||
class SearchResult(msgspec.Struct, rename="pascal"):
|
||||
"""Odata search Result"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
content_length: int
|
||||
origin_date: str
|
||||
s3_path: str
|
||||
content_date: ContentData
|
||||
geo_footprint: GeoJsonPolygon
|
||||
assets: List[Asset]
|
||||
|
||||
|
||||
class SearchContent(msgspec.Struct):
|
||||
value: List[SearchResult]
|
||||
next_link: str | None = msgspec.field(default=None, name="@odata.nextLink")
|
||||
|
||||
|
||||
async def search_odata(
|
||||
long: float,
|
||||
lat: float,
|
||||
cloud_coverage: float,
|
||||
time_lt: str,
|
||||
time_gt: str,
|
||||
max_: int,
|
||||
platform_name: str,
|
||||
) -> SearchContent:
|
||||
# Filter on cloudCover, productType and orbitDirection.
# lt = less than
# eq = equal to
# gt = greater than
# Sentinel-2
|
||||
if platform_name == "S2":
|
||||
search_filter = f"OData.CSC.Intersects(area=geography'SRID=4326;POINT ({long:.4f} {lat:.4f})') and Attributes/OData.CSC.DoubleAttribute/any(att:att/Name eq 'cloudCover' and att/OData.CSC.DoubleAttribute/Value lt {cloud_coverage:.2f}) and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType' and att/OData.CSC.StringAttribute/Value eq 'S2MSI2A') and ContentDate/Start gt {time_gt} and ContentDate/Start lt {time_lt}"
|
||||
elif platform_name == "S1":
|
||||
search_filter = f"OData.CSC.Intersects(area=geography'SRID=4326;POINT ({long:.4f} {lat:.4f})') and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType' and att/OData.CSC.StringAttribute/Value eq 'IW_GRDH_1S') and ContentDate/Start gt {time_gt} and ContentDate/Start lt {time_lt}"
|
||||
elif platform_name == "WV3":
|
||||
search_filter = f"OData.CSC.Intersects(area=geography'SRID=4326;POINT ({long:.4f} {lat:.4f})') and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'platformName' and att/OData.CSC.StringAttribute/Value eq 'WorldView-3') and ContentDate/Start gt {time_gt} and ContentDate/Start lt {time_lt}"
|
||||
else:
|
||||
raise ValueError(f"Invalid platform name: {platform_name}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(
|
||||
url=f"{ESA_SEARCH_URL}?$filter={search_filter}&$top={max_}&$expand=Assets",
|
||||
timeout=60,
|
||||
)
|
||||
if not r.status_code == 200:
|
||||
raise SearchException(f"Error getting data: {r.text}")
|
||||
return msgspec.json.decode(r.content, type=SearchContent)
|
||||
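To make the query shape concrete, this is roughly what the S2 branch of `search_odata` sends as the `$filter` value for sample arguments (wrapped over several lines here for readability; the coordinates and dates are arbitrary):

```
OData.CSC.Intersects(area=geography'SRID=4326;POINT (12.4924 41.8902)')
and Attributes/OData.CSC.DoubleAttribute/any(att:att/Name eq 'cloudCover'
    and att/OData.CSC.DoubleAttribute/Value lt 10.00)
and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType'
    and att/OData.CSC.StringAttribute/Value eq 'S2MSI2A')
and ContentDate/Start gt 2022-06-01T00:00:00.000Z
and ContentDate/Start lt 2022-07-01T00:00:00.000Z
```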
tools/pretraining_data_builder/rsi_download/download_async.py (new file, 100 lines)
@@ -0,0 +1,100 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import msgspec
|
||||
import asyncio
|
||||
from rich import print
|
||||
from rsi_download.auth import get_access_token
|
||||
from rsi_download.download.product import download_products_data
|
||||
from rsi_download.cli import (
|
||||
show_preview_urls,
|
||||
Preview,
|
||||
get_selected_products,
|
||||
)
|
||||
from rsi_download.download.search import search_odata
|
||||
import math
|
||||
|
||||
|
||||
|
||||
async def download_core(
|
||||
x: str,
|
||||
y: str,
|
||||
z: str,
|
||||
date_min: str,
|
||||
date_max: str,
|
||||
username: str,
|
||||
password: str,
|
||||
api_key: str = None,
|
||||
max_: int = 100,
|
||||
cloud_coverage: float = 20.0,
|
||||
debug: bool = False,
|
||||
tci: bool = False,
|
||||
platform_name: str = "S2",
|
||||
):
|
||||
"""
|
||||
X tile x coordinate
|
||||
Y tile y coordinate
|
||||
Z zoom level
|
||||
DATE_MIN start date in format YYYYMM
|
||||
DATE_MAX end date in format YYYYMM
|
||||
"""
|
||||
lat, long = tile_to_latlon(float(x), float(y), float(z))
|
||||
time_gt = f"{date_min[:4]}-{date_min[4:6]}-01T00:00:00.000Z"
|
||||
year = int(date_max[:4])
|
||||
month = int(date_max[4:])
|
||||
if month == 12:
|
||||
next_year = year + 1
|
||||
next_month = 1
|
||||
else:
|
||||
next_year = year
|
||||
next_month = month + 1
|
||||
time_lt = f"{next_year}-{next_month:02d}-01T00:00:00.000Z"
|
||||
|
||||
print(f"coordinates: lat: {lat:.4f}, long: {long:.4f}")
|
||||
print(f"maximum results: {max_}")
|
||||
print(f"cloud coverage percentage less then: {cloud_coverage:.2f}")
|
||||
print(f"time_gt: {time_gt}, time_lt: {time_lt}")
|
||||
search_data = await search_odata(long, lat, cloud_coverage, time_lt, time_gt, max_, platform_name)
|
||||
if debug:
|
||||
print("DEBUG: Search request data is saved to disk.")
|
||||
with open("search_data.json", "wb") as f:
|
||||
f.write(msgspec.json.encode(search_data))
|
||||
preview_urls: List[Preview] = show_preview_urls(search_data, platform_name)
|
||||
print("start downloading all data ...")
|
||||
products_to_download = get_selected_products(
|
||||
search_json=search_data, preview_urls=preview_urls, product_ids=list(range(len(preview_urls)))
|
||||
)
|
||||
tokens = get_access_token(username, password)
|
||||
|
||||
try:
|
||||
for i, (product, preview) in enumerate(zip(products_to_download, preview_urls)):
|
||||
print(f"[{i+1}/{len(products_to_download)}] downloading {product.id} ...")
|
||||
await asyncio.shield(download_products_data(
|
||||
[product], [preview], tokens.access_token, tci_only=tci
|
||||
))
|
||||
except asyncio.CancelledError:
|
||||
print("\nDownload cancelled, exiting...")
|
||||
return
|
||||
|
||||
def tile_to_latlon(x: int, y: int, z: int, get_center: bool = True) -> Tuple[float, float]:
|
||||
"""
|
||||
Convert XYZ tile coordinates to latitude/longitude
|
||||
|
||||
Args:
|
||||
x: Tile X coordinate
|
||||
y: Tile Y coordinate
|
||||
z: Zoom level
|
||||
get_center: If True, returns the center point coordinates. If False, returns the top-left corner.
|
||||
|
||||
Returns:
|
||||
Tuple of (latitude, longitude)
|
||||
"""
|
||||
n = 2.0 ** z
|
||||
if get_center:
|
||||
x += 0.5
|
||||
y += 0.5
|
||||
|
||||
lon_deg = x / n * 360.0 - 180.0
|
||||
lat_rad = math.atan(math.sinh(math.pi * (1 - 2 * y / n)))
|
||||
lat_deg = math.degrees(lat_rad)
|
||||
return lat_deg, lon_deg
|
||||
|
||||
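As a quick sanity check of `tile_to_latlon` above (a sketch; the second tile is an arbitrary example):

```
from rsi_download.download_async import tile_to_latlon

lat, lon = tile_to_latlon(0, 0, 0)        # centre of the single zoom-0 tile
print(lat, lon)                           # -> 0.0 0.0

lat, lon = tile_to_latlon(857, 418, 10)   # an arbitrary zoom-10 tile
print(f"{lat:.4f}, {lon:.4f}")
```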
tools/pretraining_data_builder/rsi_download/exceptions.py (new file, 16 lines)
@@ -0,0 +1,16 @@
class InvalidWktPointArgument(Exception):
    """Raised when the WKT string is not a valid point"""

    pass


class InvalidDateRangeArgument(Exception):
    """Raised when the daterange string is not valid"""

    pass


class SearchException(Exception):
    """Raised when the search endpoint returns a non-200 status code"""

    pass
tools/pretraining_data_builder/rsi_download/geo/geo_types.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from typing import List

import msgspec


class Coordinate(msgspec.Struct):
    long: float
    lat: float


class GeoJsonPolygon(msgspec.Struct):
    type: str
    coordinates: List[List[List[float]]]
tools/pretraining_data_builder/rsi_download/sort.py (new file, 12 lines)
@@ -0,0 +1,12 @@
def sort_by_cloudcover(search_result):
    entries = search_result["feed"]["entry"]
    sorted_entries = []
    for entry in entries:
        sorted_entries.append(
            [
                float(e["content"])
                for e in entry["double"]
                if e["name"] == "cloudcoverpercentage"
            ][0]
        )
    return sorted(sorted_entries, key=float)
tools/pretraining_data_builder/rsi_pipeline/data_builder.py (new file, 70 lines)
@@ -0,0 +1,70 @@
|
||||
import lmdb
import os
import json
from rich import print
from dotenv import load_dotenv
from rsi_download.download_async import download_core
from rsi_process.adapter import process_adapter
import asyncclick as click

# Credentials are read from the environment below; load them from a local
# .env file if one is present (python-dotenv is listed in requirements.txt).
load_dotenv()

@click.command()
|
||||
@click.argument("lmdb_path", type=click.STRING)
|
||||
async def read_lmdb_file(lmdb_path):
|
||||
"""
|
||||
Read the LMDB file and print all key-value pairs
|
||||
|
||||
Args:
|
||||
lmdb_path: LMDB file path
|
||||
"""
|
||||
if not os.path.exists(lmdb_path):
|
||||
print(f"Error: LMDB path '{lmdb_path}' does not exist")
|
||||
return
|
||||
|
||||
try:
|
||||
print(f"Reading Pretraining List from LMDB file from {lmdb_path}...")
|
||||
env = lmdb.open(lmdb_path, readonly=True)
|
||||
total_length = 0
|
||||
with env.begin() as txn:
|
||||
key = b'length'
|
||||
total_length = int(txn.get(key))
|
||||
print(f"Total length of the Pretraining Data: {total_length:,}")
|
||||
print("Example Data:")
|
||||
for i in range(10):
|
||||
print(txn.get(f"{i}".encode()).decode('utf-8'))
|
||||
for i in range(total_length):
|
||||
key = f"{i}".encode()
|
||||
data = json.loads(txn.get(key).decode('utf-8'))
|
||||
print("*"* 116 + "\n" + f"* Current Data [{i+1} / {total_length}]: {data} *" + "\n" + "*"* 116 )
|
||||
print(f"Downloading: {data}")
|
||||
await download_core(
|
||||
x=data['x'],
|
||||
y=data['y'],
|
||||
z=data['z'],
|
||||
date_min=data['date_min'],
|
||||
date_max=data['date_max'],
|
||||
username=os.getenv("USERNAME"),
|
||||
password=os.getenv("PASSWORD"),
|
||||
cloud_coverage=20.0,
|
||||
tci=False
|
||||
)
|
||||
print('-'* 40)
|
||||
print(f"Processing: {data}")
|
||||
process_list = os.listdir('out_raw/')
|
||||
total_len_process = len(process_list)
|
||||
for j, fn in enumerate(process_list):
print(f"Processing: {fn} [{j+1} / {total_len_process}]...")
|
||||
process_adapter(
|
||||
fn_img=f'out_raw/{fn}',
|
||||
save_dir='out_processed/',
|
||||
verbose=True,
|
||||
use_gcj02=False
|
||||
)
|
||||
print('-'* 40)
|
||||
print("Done!")
|
||||
|
||||
except lmdb.Error as e:
|
||||
print(f"Error reading LMDB file: {str(e)}")
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
read_lmdb_file()
|
||||
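The username and password above are read with `os.getenv`; since `.env` is gitignored and `python-dotenv` is in requirements.txt, they are presumably supplied through a local `.env` file along these lines (variable names taken from the `os.getenv` calls above, values are placeholders):

```
USERNAME=your_cdse_username
PASSWORD=your_cdse_password
```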
tools/pretraining_data_builder/rsi_process/__main__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import argparse
from rsi_process.adapter import process_adapter


def get_main_parser():
    parser = argparse.ArgumentParser(description='RSI Processing Pipeline')
    parser.add_argument('--platform', choices=['s1', 's2', 'wv'], help='satellite platform; the platform is also inferred from the file name by process_adapter')
    parser.add_argument('--fn_img', help='input zip file')
    parser.add_argument('--save_dir', default='output/', help='output directory')
    parser.add_argument('--verbose', action='store_true', default=True, help='whether to print info')
    parser.add_argument('--use_gcj02', action='store_true', default=False, help='whether to use the GCJ02 coordinate system')
    return parser


def main():
    parser = get_main_parser()
    args = parser.parse_args()
    process_adapter(args.fn_img, args.save_dir, args.verbose, args.use_gcj02)


if __name__ == '__main__':
    main()
tools/pretraining_data_builder/rsi_process/adapter.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from rsi_process.script_s1_tiles import process_s1
from rsi_process.script_s2_tiles import process_s2
from rsi_process.script_wv_tiles import process_wv
from easydict import EasyDict as edict  # requires the easydict package


def process_adapter(fn_img, save_dir, verbose, use_gcj02):
    satellite_info = fn_img.split('/')[-1].split('_')[0]
    if 'S2' in satellite_info:
        satellite = 'S2'
    elif 'S1' in satellite_info:
        satellite = 'S1'
    elif 'WV' in satellite_info:
        satellite = 'WV'
    else:
        raise ValueError(f'Cannot infer the satellite platform from file name: {fn_img}')
    args = edict(fn_img=fn_img, save_dir=save_dir, verbose=verbose, use_gcj02=use_gcj02)
    if satellite == 'S1':
        process_s1(args)
    elif satellite == 'S2':
        process_s2(args)
    elif satellite == 'WV':
        process_wv(args)
tools/pretraining_data_builder/rsi_process/script_s1_tiles.py (new file, 238 lines)
@@ -0,0 +1,238 @@
|
||||
import os
|
||||
import uuid
|
||||
import numpy as np
|
||||
import pyproj as prj
|
||||
from osgeo import gdal
|
||||
from time import time
|
||||
import mercantile
|
||||
from PIL import Image
|
||||
# Package-qualified imports so this module also works when imported via rsi_process.adapter.
from rsi_process import utils_s1
import imageio.v2 as iio
from rsi_process.tile_resample import (
get_tile_array,
transfer
)
|
||||
|
||||
import argparse
|
||||
from rich import print
|
||||
from rich.progress import track
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser(description='Sentinel-1 to GCJ02 tiles')
|
||||
parser.add_argument('--fn_img', help='input zip file of Sentinel-1 L1C')
|
||||
parser.add_argument('--save_dir', default='output_s1/', help='prefix on oss bucket')
|
||||
parser.add_argument('--verbose', action='store_true', default=True, help='whether to print info')
|
||||
parser.add_argument('--use_gcj02', action='store_true', default=False, help='whether to use GCJ02 coordinate system')
|
||||
return parser
|
||||
|
||||
def process_s1(args):
|
||||
t_start = time()
|
||||
fn_img = args.fn_img
|
||||
max_target_file = os.path.basename(fn_img).split('_')[4][0:8]  # acquisition start date from the Sentinel-1 product name
|
||||
verbose = args.verbose
|
||||
save_rgb = False  # Sentinel-1 tiles contain only VV/VH, so no RGB preview can be written
|
||||
nodata = 0
|
||||
|
||||
save_dir = args.save_dir
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
thumb_save_dir = os.path.join(save_dir, 'thumb')
|
||||
os.makedirs(thumb_save_dir, exist_ok=True)
|
||||
|
||||
print(f"converting {fn_img}...")
|
||||
|
||||
z = 14
|
||||
bands = ['VV', 'VH']
|
||||
buf = 1
|
||||
|
||||
def get_image_by_approximate_boundary(boundary):
|
||||
'''
|
||||
boundary: iterable of (lng, lat) in wgs84
|
||||
'''
|
||||
arr_lnglat = np.array(boundary)
|
||||
xx, yy = tr_from_4326.transform(arr_lnglat[:, 0], arr_lnglat[:, 1])
|
||||
row_min = int((tr[3] - yy.max()) / yres)
|
||||
row_max = int((tr[3] - yy.min()) / yres)
|
||||
col_min = int((xx.min() - tr[0]) / xres)
|
||||
col_max = int((xx.max() - tr[0]) / xres)
|
||||
row_min = max(0, row_min - buf)
|
||||
row_max = min(ny - 1, row_max + buf)
|
||||
col_min = max(0, col_min - buf)
|
||||
col_max = min(nx - 1, col_max + buf)
|
||||
if row_min > row_max or col_min > col_max:
|
||||
return None
|
||||
|
||||
arr_image = np.stack([
|
||||
ds.ReadAsArray(col_min, row_min, col_max - col_min + 1, row_max - row_min + 1)
|
||||
for ds in list_arr
|
||||
])
|
||||
|
||||
for iband in range(arr_image.shape[0]):
|
||||
if np.any(arr_image[iband] != nodata):
|
||||
break
|
||||
else:
|
||||
return None
|
||||
arr_image = arr_image.transpose((1, 2, 0))
|
||||
if arr_image.shape[2] == 1:
|
||||
arr_image = arr_image[:, :, 0]
|
||||
arr_xx = tr[0] + np.arange(col_min, col_max + 1) * xres
|
||||
arr_yy = tr[3] - np.arange(row_min, row_max + 1) * yres
|
||||
arr_xx, arr_yy = np.meshgrid(arr_xx, arr_yy)
|
||||
arr_lngs, arr_lats = tr_to_4326.transform(arr_xx, arr_yy)
|
||||
return arr_image, arr_lngs, arr_lats
|
||||
|
||||
|
||||
rec = utils_s1.zip2rec(fn_img)
|
||||
# import pdb; pdb.set_trace()
|
||||
os.makedirs(os.path.join(thumb_save_dir, rec['sensing_start'].replace('-', '')), exist_ok=True)
|
||||
thumb_save_path = os.path.join(thumb_save_dir, rec['sensing_start'].replace('-', ''), rec['product_uri'].replace('SAFE', 'png'))
|
||||
iio.imwrite(thumb_save_path, rec['thumb'])
|
||||
|
||||
list_arr = []
|
||||
for band in bands:
|
||||
fn_jp2 = utils_s1.make_full_name(rec, band=band)
|
||||
# import pdb; pdb.set_trace()
|
||||
fn_jp2 = '/vsizip/' + os.path.join(fn_img, fn_jp2)
|
||||
ds = gdal.Open(fn_jp2)
|
||||
list_arr.append(ds)
|
||||
if band == bands[0]:
|
||||
nx, ny = ds.RasterXSize, ds.RasterYSize
|
||||
if verbose: print('input size:', nx, ny)
|
||||
tr = ds.GetGeoTransform()
|
||||
if verbose:
|
||||
print(gdal.Info(ds, format='json'))
|
||||
# import pdb; pdb.set_trace()
|
||||
try:
|
||||
proj_wkt = ds.GetProjectionRef()
|
||||
if proj_wkt:
|
||||
srs = prj.CRS.from_wkt(proj_wkt)
|
||||
epsg = int(srs.to_epsg())
|
||||
else:
|
||||
proj_wkt = ds.GetGCPProjection()
|
||||
if proj_wkt:
|
||||
srs = prj.CRS.from_wkt(proj_wkt)
|
||||
epsg = int(srs.to_epsg())
|
||||
else:
|
||||
print("Warning: No projection information found, using default value 4326 (WGS84)")
|
||||
epsg = 4326
|
||||
except Exception as e:
|
||||
print(f"Warning: Unable to get EPSG code, using default value 4326 (WGS84). Error: {e}")
|
||||
epsg = 4326
|
||||
|
||||
if verbose:
|
||||
print(f"Used EPSG code: {epsg}")
|
||||
|
||||
size_pixel = mercantile.CE / 2 ** z / 256
|
||||
radius = np.ceil(max(tr[1], -tr[5]) / size_pixel * 1.5)
|
||||
|
||||
buf_ext = buf
|
||||
xmin = tr[0] - buf_ext * tr[1]
|
||||
ymin = tr[3] + (ny + buf_ext) * tr[5]
|
||||
xmax = tr[0] + (nx + buf_ext) * tr[1]
|
||||
ymax = tr[3] - buf_ext * tr[5]
|
||||
xres = tr[1]
|
||||
yres = - tr[5]
|
||||
if verbose:
|
||||
print(
|
||||
f'input extent, WGS84, buffered by {buf_ext} pixels: {xmin}, {ymin}, {xmax}, {ymax}'
|
||||
)
|
||||
|
||||
tr_to_4326 = prj.Transformer.from_crs(epsg, 4326, always_xy=True)
|
||||
tr_from_4326 = prj.Transformer.from_crs(4326, epsg, always_xy=True)
|
||||
arr_lng, arr_lat = tr_to_4326.transform(
|
||||
np.array([xmin, xmin, xmax, xmax]),
|
||||
np.array([ymax, ymin, ymin, ymax])
|
||||
)
|
||||
# import pdb; pdb.set_trace()
|
||||
if args.use_gcj02:
|
||||
arr_lng_final, arr_lat_final = transfer.WGS84_to_GCJ02(arr_lng, arr_lat)
|
||||
else:
|
||||
arr_lng_final, arr_lat_final = arr_lng, arr_lat
|
||||
|
||||
box = (
|
||||
arr_lng_final.min(),
|
||||
arr_lat_final.min(),
|
||||
arr_lng_final.max(),
|
||||
arr_lat_final.max()
|
||||
)
|
||||
|
||||
if verbose:
|
||||
coord_system = "GCJ02" if args.use_gcj02 else "WGS84"
|
||||
print(f'input extent, {coord_system}: {box}')
|
||||
|
||||
tile_ul = mercantile.tile(box[0], box[3], z)
|
||||
tile_lr = mercantile.tile(box[2], box[1], z)
|
||||
|
||||
if verbose:
|
||||
print('Upperleft ', str(tile_ul))
|
||||
print('Lowerright ', str(tile_lr))
|
||||
|
||||
def work(x, y, z, save_rgb):
|
||||
arr_tile = get_tile_array(
|
||||
x, y, z,
|
||||
method='nearest',
|
||||
func_source=get_image_by_approximate_boundary,
|
||||
radius=radius,
|
||||
use_gc02=args.use_gcj02
|
||||
)
|
||||
y_str = str(y)
|
||||
if arr_tile is not None:
|
||||
indi_gap = arr_tile[:, :, 0] == 0
|
||||
|
||||
dict_arr = {
|
||||
band: arr_tile[:, :, i_band]
|
||||
for i_band, band in enumerate(bands)
|
||||
}
|
||||
save_path = os.path.join(save_dir, str(z), str(x))
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
npz_filename = os.path.join(save_path, f'{y_str}_{max_target_file}.npz')
|
||||
|
||||
if indi_gap.any():
|
||||
if os.path.exists(npz_filename):
|
||||
try:
|
||||
fp = np.load(npz_filename)
|
||||
for band in bands:
|
||||
dict_arr[band][indi_gap] = fp[band][indi_gap]
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("datasize is 0", npz_filename)
|
||||
pass
|
||||
|
||||
np.savez_compressed(npz_filename, **dict_arr)
|
||||
if verbose:
|
||||
print(f"npz file for X={str(x)}, Y={y_str}, Z={str(z)} date={max_target_file} generated!")
|
||||
if save_rgb:
|
||||
arr_rgb = np.stack([dict_arr['B4'], dict_arr['B3'], dict_arr['B2']], axis=-1)
|
||||
arr_rgb = np.clip(arr_rgb / 3000. * 255, 0, 255).astype(np.uint8)
|
||||
image_tile = Image.fromarray(arr_rgb)
|
||||
|
||||
png_filename = os.path.join(save_path, f'{y_str}_{max_target_file}.png')
|
||||
image_tile.save(png_filename, format='png')
|
||||
|
||||
diff_list = []
|
||||
|
||||
tasks = [
|
||||
(x, y) for x in range(tile_ul.x, tile_lr.x + 1)
|
||||
for y in range(tile_ul.y, tile_lr.y + 1)
|
||||
]
|
||||
|
||||
for x, y in track(tasks, description="converting tiles..."):
|
||||
work(x, y, z, save_rgb)
|
||||
diff_list.append(os.path.join(str(z), str(x), f'{y}_{max_target_file}.npz'))
|
||||
|
||||
diff_path = os.path.join(save_dir, 'diff', 'new')
|
||||
os.makedirs(diff_path, exist_ok=True)
|
||||
diff_filename = os.path.join(diff_path, f"{z}-{os.path.splitext(os.path.basename(fn_img))[0]}-{uuid.uuid1()}.txt")
|
||||
with open(diff_filename, 'w') as f:
|
||||
f.write('\n'.join(diff_list))
|
||||
|
||||
print("time cost :", time() - t_start)
|
||||
|
||||
def main():
|
||||
args = get_args_parser().parse_args()
|
||||
process_s1(args)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
tools/pretraining_data_builder/rsi_process/script_s2_tiles.py (new file, 226 lines)
@@ -0,0 +1,226 @@
|
||||
import os
|
||||
import uuid
|
||||
import numpy as np
|
||||
import pyproj as prj
|
||||
from osgeo import gdal
|
||||
from time import time
|
||||
import mercantile
|
||||
from PIL import Image
|
||||
# Package-qualified imports so this module also works when imported via rsi_process.adapter.
from rsi_process import utils_s2
import imageio.v2 as iio
from rsi_process.tile_resample import (
get_tile_array,
transfer
)
|
||||
|
||||
import argparse
|
||||
from rich import print
|
||||
from rich.progress import track
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser(description='Sentinel-2 to GCJ02 tiles')
|
||||
parser.add_argument('--fn_img', help='input zip file of Sentinel-2 L2A') # /Users/wukang/Projects/sentinel2-downloader/S2A_MSIL2A_20220615T024601_N0400_R132_T50SNA_20220615T062308.zip
|
||||
parser.add_argument('--resolution', type=int, help='10 or 20 meter resolution bands')
|
||||
parser.add_argument('--save_dir', default='output_s2/', help='prefix on oss bucket')
|
||||
parser.add_argument('--verbose', action='store_true', default=True, help='whether to print info')
|
||||
parser.add_argument('--use_gcj02', action='store_true', default=False, help='whether to use GCJ02 coordinate system')
|
||||
return parser.parse_args()
|
||||
|
||||
def process_s2(args):
|
||||
t_start = time()
|
||||
fn_img = args.fn_img
|
||||
max_target_file = fn_img.split('_')[2][0:8]
|
||||
resolution = getattr(args, 'resolution', 10)  # default to the 10 m bands when the caller (e.g. process_adapter) does not set resolution
|
||||
verbose = args.verbose
|
||||
save_rgb = True
|
||||
nodata = 0
|
||||
|
||||
save_dir = args.save_dir
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
thumb_save_dir = os.path.join(save_dir, 'thumb')
|
||||
os.makedirs(thumb_save_dir, exist_ok=True)
|
||||
|
||||
print(f"converting {fn_img}...")
|
||||
if resolution == 10:
|
||||
z = 14
|
||||
bands = ['B2', 'B3', 'B4', 'B8']
|
||||
buf = 1
|
||||
elif resolution == 20:
|
||||
z = 13
|
||||
bands = ['B5', 'B6', 'B7', 'B8A', 'B11', 'B12', 'SCL']
|
||||
buf = 1
|
||||
save_rgb = False
|
||||
else:
|
||||
raise Exception(f'Unknown resolution: {resolution}')
|
||||
|
||||
|
||||
def get_image_by_approximate_boundary(boundary):
|
||||
'''
|
||||
boundary: iterable of (lng, lat) in wgs84
|
||||
'''
|
||||
arr_lnglat = np.array(boundary)
|
||||
xx, yy = tr_from_4326.transform(arr_lnglat[:, 0], arr_lnglat[:, 1])
|
||||
row_min = int((tr[3] - yy.max()) / yres)
|
||||
row_max = int((tr[3] - yy.min()) / yres)
|
||||
col_min = int((xx.min() - tr[0]) / xres)
|
||||
col_max = int((xx.max() - tr[0]) / xres)
|
||||
row_min = max(0, row_min - buf)
|
||||
row_max = min(ny - 1, row_max + buf)
|
||||
col_min = max(0, col_min - buf)
|
||||
col_max = min(nx - 1, col_max + buf)
|
||||
if row_min > row_max or col_min > col_max:
|
||||
return None
|
||||
|
||||
arr_image = np.stack([
|
||||
ds.ReadAsArray(col_min, row_min, col_max - col_min + 1, row_max - row_min + 1)
|
||||
for ds in list_arr
|
||||
])
|
||||
|
||||
for iband in range(arr_image.shape[0]):
|
||||
if np.any(arr_image[iband] != nodata):
|
||||
break
|
||||
else:
|
||||
return None
|
||||
arr_image = arr_image.transpose((1, 2, 0))
|
||||
if arr_image.shape[2] == 1:
|
||||
arr_image = arr_image[:, :, 0]
|
||||
arr_xx = tr[0] + np.arange(col_min, col_max + 1) * xres
|
||||
arr_yy = tr[3] - np.arange(row_min, row_max + 1) * yres
|
||||
arr_xx, arr_yy = np.meshgrid(arr_xx, arr_yy)
|
||||
arr_lngs, arr_lats = tr_to_4326.transform(arr_xx, arr_yy)
|
||||
return arr_image, arr_lngs, arr_lats
|
||||
|
||||
|
||||
rec = utils_s2.zip2rec(fn_img)
|
||||
|
||||
os.makedirs(os.path.join(thumb_save_dir, rec['sensing_start'].replace('-', '')), exist_ok=True)
|
||||
thumb_save_path = os.path.join(thumb_save_dir, rec['sensing_start'].replace('-', ''), rec['product_uri'].replace('SAFE', 'png'))
|
||||
iio.imwrite(thumb_save_path, rec['thumb'])
|
||||
|
||||
list_arr = []
|
||||
for band in bands:
|
||||
fn_jp2 = utils_s2.make_full_name(rec, band=band)
|
||||
fn_jp2 = '/vsizip/' + os.path.join(fn_img, fn_jp2)
|
||||
ds = gdal.Open(fn_jp2)
|
||||
list_arr.append(ds)
|
||||
if band == bands[0]:
|
||||
nx, ny = ds.RasterXSize, ds.RasterYSize
|
||||
if verbose: print('input size:', nx, ny)
|
||||
tr = ds.GetGeoTransform()
|
||||
if verbose:
|
||||
print(gdal.Info(ds, format='json'))
|
||||
epsg = int(
|
||||
gdal.Info(ds, format='json')['coordinateSystem']['wkt'].rsplit('"EPSG",', 1)[-1][:-2]
|
||||
)
|
||||
|
||||
size_pixel = mercantile.CE / 2 ** z / 256
|
||||
radius = np.ceil(max(tr[1], -tr[5]) / size_pixel * 1.5)
|
||||
|
||||
buf_ext = buf
|
||||
xmin = tr[0] - buf_ext * tr[1]
|
||||
ymin = tr[3] + (ny + buf_ext) * tr[5]
|
||||
xmax = tr[0] + (nx + buf_ext) * tr[1]
|
||||
ymax = tr[3] - buf_ext * tr[5]
|
||||
xres = tr[1]
|
||||
yres = - tr[5]
|
||||
if verbose:
|
||||
print(
|
||||
f'input extent, WGS84, buffered by {buf_ext} pixels: {xmin}, {ymin}, {xmax}, {ymax}'
|
||||
)
|
||||
|
||||
tr_to_4326 = prj.Transformer.from_crs(epsg, 4326, always_xy=True)
|
||||
tr_from_4326 = prj.Transformer.from_crs(4326, epsg, always_xy=True)
|
||||
arr_lng, arr_lat = tr_to_4326.transform(
|
||||
np.array([xmin, xmin, xmax, xmax]),
|
||||
np.array([ymax, ymin, ymin, ymax])
|
||||
)
|
||||
|
||||
if args.use_gcj02:
|
||||
arr_lng_final, arr_lat_final = transfer.WGS84_to_GCJ02(arr_lng, arr_lat)
|
||||
else:
|
||||
arr_lng_final, arr_lat_final = arr_lng, arr_lat
|
||||
|
||||
box = (
|
||||
arr_lng_final.min(),
|
||||
arr_lat_final.min(),
|
||||
arr_lng_final.max(),
|
||||
arr_lat_final.max()
|
||||
)
|
||||
|
||||
if verbose:
|
||||
coord_system = "GCJ02" if args.use_gcj02 else "WGS84"
|
||||
print(f'input extent, {coord_system}: {box}')
|
||||
|
||||
tile_ul = mercantile.tile(box[0], box[3], z)
|
||||
tile_lr = mercantile.tile(box[2], box[1], z)
|
||||
|
||||
if verbose:
|
||||
print('Upperleft ', str(tile_ul))
|
||||
print('Lowerright ', str(tile_lr))
|
||||
|
||||
def work(x, y, z, save_rgb):
|
||||
arr_tile = get_tile_array(
|
||||
x, y, z,
|
||||
method='nearest',
|
||||
func_source=get_image_by_approximate_boundary,
|
||||
radius=radius,
|
||||
use_gc02=args.use_gcj02
|
||||
)
|
||||
y_str = str(y)
|
||||
if arr_tile is not None:
|
||||
indi_gap = arr_tile[:, :, 0] == 0
|
||||
|
||||
dict_arr = {
|
||||
band: arr_tile[:, :, i_band]
|
||||
for i_band, band in enumerate(bands)
|
||||
}
|
||||
save_path = os.path.join(save_dir, str(z), str(x))
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
npz_filename = os.path.join(save_path, f'{y_str}_{max_target_file}.npz')
|
||||
|
||||
if indi_gap.any():
|
||||
if os.path.exists(npz_filename):
|
||||
try:
|
||||
fp = np.load(npz_filename)
|
||||
for band in bands:
|
||||
dict_arr[band][indi_gap] = fp[band][indi_gap]
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("datasize is 0", npz_filename)
|
||||
pass
|
||||
|
||||
np.savez_compressed(npz_filename, **dict_arr)
|
||||
if verbose:
|
||||
print(f"npz file for X={str(x)}, Y={y_str}, Z={str(z)} date={max_target_file} generated!")
|
||||
if save_rgb:
|
||||
arr_rgb = np.stack([dict_arr['B4'], dict_arr['B3'], dict_arr['B2']], axis=-1)
|
||||
arr_rgb = np.clip(arr_rgb / 3000. * 255, 0, 255).astype(np.uint8)
|
||||
image_tile = Image.fromarray(arr_rgb)
|
||||
|
||||
png_filename = os.path.join(save_path, f'{y_str}_{max_target_file}.png')
|
||||
image_tile.save(png_filename, format='png')
|
||||
|
||||
diff_list = []
|
||||
|
||||
tasks = [
|
||||
(x, y) for x in range(tile_ul.x, tile_lr.x + 1)
|
||||
for y in range(tile_ul.y, tile_lr.y + 1)
|
||||
]
|
||||
|
||||
for x, y in track(tasks, description="converting tiles..."):
|
||||
work(x, y, z, save_rgb)
|
||||
diff_list.append(os.path.join(str(z), str(x), f'{y}_{max_target_file}.npz'))
|
||||
|
||||
diff_path = os.path.join(save_dir, 'diff', 'new')
|
||||
os.makedirs(diff_path, exist_ok=True)
|
||||
diff_filename = os.path.join(diff_path, f"{z}-{os.path.splitext(os.path.basename(fn_img))[0]}-{uuid.uuid1()}.txt")
|
||||
with open(diff_filename, 'w') as f:
|
||||
f.write('\n'.join(diff_list))
|
||||
|
||||
print("time cost :", time() - t_start)
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args_parser()
|
||||
process_s2(args)
|
||||
tools/pretraining_data_builder/rsi_process/script_wv_tiles.py (new file, 183 lines)
@@ -0,0 +1,183 @@
|
||||
import os
|
||||
import uuid
|
||||
import numpy as np
|
||||
import pyproj as prj
|
||||
from osgeo import gdal
|
||||
from time import time
|
||||
import mercantile
|
||||
from PIL import Image
|
||||
import imageio.v2 as iio
|
||||
from rsi_process.tile_resample import (
get_tile_array,
transfer
)
|
||||
import argparse
|
||||
from rich import print
|
||||
from rich.progress import track
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser(description='WorldView to tiles')
|
||||
parser.add_argument('--fn_img', help='input file of WorldView image')
|
||||
parser.add_argument('--save_dir', default='output_wv/', help='output directory')
|
||||
parser.add_argument('--zoom', type=int, default=16, help='zoom level')
|
||||
parser.add_argument('--verbose', action='store_true', default=True)
|
||||
parser.add_argument('--use_gcj02', action='store_true', default=False)
|
||||
return parser.parse_args()
|
||||
|
||||
def get_image_by_approximate_boundary(ds_list, boundary, tr, buf=1):
|
||||
'''Get image data within a specified boundary
|
||||
|
||||
Args:
|
||||
ds_list: List of GDAL datasets
|
||||
boundary: List of (lng, lat) coordinates
|
||||
tr: Geotransformation parameters
|
||||
buf: Buffer size
|
||||
'''
|
||||
arr_lnglat = np.array(boundary)
|
||||
tr_from_4326 = prj.Transformer.from_crs(4326, ds_list[0].GetProjection(), always_xy=True)
|
||||
|
||||
xx, yy = tr_from_4326.transform(arr_lnglat[:, 0], arr_lnglat[:, 1])
|
||||
|
||||
nx = ds_list[0].RasterXSize
|
||||
ny = ds_list[0].RasterYSize
|
||||
xres = tr[1]
|
||||
yres = -tr[5]
|
||||
|
||||
row_min = int((tr[3] - yy.max()) / yres)
|
||||
row_max = int((tr[3] - yy.min()) / yres)
|
||||
col_min = int((xx.min() - tr[0]) / xres)
|
||||
col_max = int((xx.max() - tr[0]) / xres)
|
||||
|
||||
row_min = max(0, row_min - buf)
|
||||
row_max = min(ny - 1, row_max + buf)
|
||||
col_min = max(0, col_min - buf)
|
||||
col_max = min(nx - 1, col_max + buf)
|
||||
|
||||
if row_min > row_max or col_min > col_max:
|
||||
return None
|
||||
|
||||
arr_image = np.stack([
|
||||
ds.ReadAsArray(col_min, row_min, col_max - col_min + 1, row_max - row_min + 1)
|
||||
for ds in ds_list
|
||||
])
|
||||
|
||||
if np.all(arr_image == 0):
|
||||
return None
|
||||
|
||||
arr_image = arr_image.transpose((1, 2, 0))
|
||||
|
||||
arr_xx = tr[0] + np.arange(col_min, col_max + 1) * xres
|
||||
arr_yy = tr[3] - np.arange(row_min, row_max + 1) * yres
|
||||
arr_xx, arr_yy = np.meshgrid(arr_xx, arr_yy)
|
||||
|
||||
tr_to_4326 = prj.Transformer.from_crs(ds_list[0].GetProjection(), 4326, always_xy=True)
|
||||
arr_lngs, arr_lats = tr_to_4326.transform(arr_xx, arr_yy)
|
||||
|
||||
return arr_image, arr_lngs, arr_lats
|
||||
|
||||
def process_wv(args):
|
||||
t_start = time()
|
||||
|
||||
fn_img = args.fn_img
|
||||
save_dir = args.save_dir
|
||||
z = args.zoom
|
||||
verbose = args.verbose
|
||||
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
ds = gdal.Open(fn_img)
|
||||
if ds is None:
|
||||
raise Exception(f"Cannot open {fn_img}")
|
||||
|
||||
bands = [ds.GetRasterBand(i+1) for i in range(ds.RasterCount)]
|
||||
list_arr = [ds]
|
||||
|
||||
nx, ny = ds.RasterXSize, ds.RasterYSize
|
||||
tr = ds.GetGeoTransform()
|
||||
|
||||
if verbose:
|
||||
print('Input size:', nx, ny)
|
||||
print(gdal.Info(ds, format='json'))
|
||||
|
||||
# Calculate the image range
|
||||
size_pixel = mercantile.CE / 2 ** z / 256
|
||||
radius = np.ceil(max(tr[1], -tr[5]) / size_pixel * 1.5)
|
||||
|
||||
buf_ext = 1
|
||||
xmin = tr[0] - buf_ext * tr[1]
|
||||
ymin = tr[3] + (ny + buf_ext) * tr[5]
|
||||
xmax = tr[0] + (nx + buf_ext) * tr[1]
|
||||
ymax = tr[3] - buf_ext * tr[5]
|
||||
|
||||
tr_to_4326 = prj.Transformer.from_crs(ds.GetProjection(), 4326, always_xy=True)
|
||||
arr_lng, arr_lat = tr_to_4326.transform(
|
||||
np.array([xmin, xmin, xmax, xmax]),
|
||||
np.array([ymax, ymin, ymin, ymax])
|
||||
)
|
||||
|
||||
if args.use_gcj02:
|
||||
arr_lng_final, arr_lat_final = transfer.WGS84_to_GCJ02(arr_lng, arr_lat)
|
||||
else:
|
||||
arr_lng_final, arr_lat_final = arr_lng, arr_lat
|
||||
|
||||
box = (
|
||||
arr_lng_final.min(),
|
||||
arr_lat_final.min(),
|
||||
arr_lng_final.max(),
|
||||
arr_lat_final.max()
|
||||
)
|
||||
|
||||
if verbose:
|
||||
coord_system = "GCJ02" if args.use_gcj02 else "WGS84"
|
||||
print(f'Input extent, {coord_system}: {box}')
|
||||
|
||||
# Calculate the tile range to be processed
|
||||
tile_ul = mercantile.tile(box[0], box[3], z)
|
||||
tile_lr = mercantile.tile(box[2], box[1], z)
|
||||
|
||||
if verbose:
|
||||
print('Upperleft ', str(tile_ul))
|
||||
print('Lowerright ', str(tile_lr))
|
||||
|
||||
def work(x, y, z):
|
||||
arr_tile = get_tile_array(
|
||||
x, y, z,
|
||||
method='nearest',
|
||||
func_source=lambda boundary: get_image_by_approximate_boundary(list_arr, boundary, tr),
|
||||
radius=radius,
|
||||
use_gc02=args.use_gcj02
|
||||
)
|
||||
|
||||
if arr_tile is not None:
|
||||
save_path = os.path.join(save_dir, str(z), str(x))
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Save as PNG
|
||||
if arr_tile.shape[2] >= 3:
|
||||
arr_rgb = arr_tile[:, :, :3]
|
||||
arr_rgb = np.clip(arr_rgb / 2000. * 255, 0, 255).astype(np.uint8)
|
||||
image_tile = Image.fromarray(arr_rgb)
|
||||
png_filename = os.path.join(save_path, f'{y}.png')
|
||||
image_tile.save(png_filename, format='png')
|
||||
|
||||
# Save as NPZ
|
||||
dict_arr = {f'B{i+1}': arr_tile[:, :, i] for i in range(arr_tile.shape[2])}
|
||||
npz_filename = os.path.join(save_path, f'{y}.npz')
|
||||
np.savez_compressed(npz_filename, **dict_arr)
|
||||
|
||||
tasks = [
|
||||
(x, y) for x in range(tile_ul.x, tile_lr.x + 1)
|
||||
for y in range(tile_ul.y, tile_lr.y + 1)
|
||||
]
|
||||
|
||||
for x, y in track(tasks, description="Converting tiles..."):
|
||||
work(x, y, z)
|
||||
|
||||
print("Time cost:", time() - t_start)
|
||||
|
||||
def main():
|
||||
args = get_args_parser()
|
||||
process_wv(args)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
tools/pretraining_data_builder/rsi_process/tile_resample.py (new file, 230 lines)
@@ -0,0 +1,230 @@
|
||||
import numpy as np
|
||||
import mercantile
|
||||
from pyresample import bilinear, kd_tree, geometry
|
||||
|
||||
TILE_SIZE = 256
|
||||
|
||||
class LngLatTransfer():
|
||||
|
||||
def __init__(self):
|
||||
self.x_pi = 3.14159265358979324 * 3000.0 / 180.0
|
||||
self.pi = np.pi # π
|
||||
self.a = 6378245.0
|
||||
self.es = 0.00669342162296594323
|
||||
pass
|
||||
|
||||
def GCJ02_to_BD09(self, gcj_lng, gcj_lat):
|
||||
"""
|
||||
Convert coordinates from GCJ02 to BD09 coordinate system
|
||||
:param lng: Longitude in GCJ02 coordinate system
|
||||
:param lat: Latitude in GCJ02 coordinate system
|
||||
:return: Converted longitude and latitude in BD09
|
||||
"""
|
||||
z = np.sqrt(gcj_lng * gcj_lng + gcj_lat * gcj_lat) + 0.00002 * np.sin(gcj_lat * self.x_pi)
|
||||
theta = np.arctan2(gcj_lat, gcj_lng) + 0.000003 * np.cos(gcj_lng * self.x_pi)
|
||||
bd_lng = z * np.cos(theta) + 0.0065
|
||||
bd_lat = z * np.sin(theta) + 0.006
|
||||
return bd_lng, bd_lat
|
||||
|
||||
|
||||
def BD09_to_GCJ02(self, bd_lng, bd_lat):
|
||||
'''
|
||||
Convert coordinates from BD09 to GCJ02 coordinate system
|
||||
:param bd_lng: Longitude in BD09 coordinate system
|
||||
:param bd_lat: Latitude in BD09 coordinate system
|
||||
:return: Converted longitude and latitude in GCJ02
|
||||
'''
|
||||
x = bd_lng - 0.0065
|
||||
y = bd_lat - 0.006
|
||||
z = np.sqrt(x * x + y * y) - 0.00002 * np.sin(y * self.x_pi)
|
||||
theta = np.arctan2(y, x) - 0.000003 * np.cos(x * self.x_pi)
|
||||
gcj_lng = z * np.cos(theta)
|
||||
gcj_lat = z * np.sin(theta)
|
||||
return gcj_lng, gcj_lat
|
||||
|
||||
|
||||
def WGS84_to_GCJ02(self, lng, lat):
|
||||
'''
|
||||
Convert coordinates from WGS84 to GCJ02 coordinate system
|
||||
:param lng: Longitude in WGS84 coordinate system
|
||||
:param lat: Latitude in WGS84 coordinate system
|
||||
:return: Converted longitude and latitude in GCJ02
|
||||
'''
|
||||
dlat = self._transformlat(lng - 105.0, lat - 35.0)
|
||||
dlng = self._transformlng(lng - 105.0, lat - 35.0)
|
||||
radlat = lat / 180.0 * self.pi
|
||||
magic = np.sin(radlat)
|
||||
magic = 1 - self.es * magic * magic
|
||||
sqrtmagic = np.sqrt(magic)
|
||||
dlat = (dlat * 180.0) / ((self.a * (1 - self.es)) / (magic * sqrtmagic) * self.pi)
|
||||
dlng = (dlng * 180.0) / (self.a / sqrtmagic * np.cos(radlat) * self.pi)
|
||||
gcj_lng = lng + dlng
|
||||
gcj_lat = lat + dlat
|
||||
return gcj_lng, gcj_lat
|
||||
|
||||
|
||||
def GCJ02_to_WGS84(self, gcj_lng, gcj_lat):
|
||||
'''
|
||||
Convert coordinates from GCJ02 to WGS84 coordinate system
|
||||
:param gcj_lng: Longitude in GCJ02 coordinate system
|
||||
:param gcj_lat: Latitude in GCJ02 coordinate system
|
||||
:return: Converted longitude and latitude in WGS84
|
||||
'''
|
||||
dlat = self._transformlat(gcj_lng - 105.0, gcj_lat - 35.0)
|
||||
dlng = self._transformlng(gcj_lng - 105.0, gcj_lat - 35.0)
|
||||
radlat = gcj_lat / 180.0 * self.pi
|
||||
magic = np.sin(radlat)
|
||||
magic = 1 - self.es * magic * magic
|
||||
sqrtmagic = np.sqrt(magic)
|
||||
dlat = (dlat * 180.0) / ((self.a * (1 - self.es)) / (magic * sqrtmagic) * self.pi)
|
||||
dlng = (dlng * 180.0) / (self.a / sqrtmagic * np.cos(radlat) * self.pi)
|
||||
mglat = gcj_lat + dlat
|
||||
mglng = gcj_lng + dlng
|
||||
lng = gcj_lng * 2 - mglng
|
||||
lat = gcj_lat * 2 - mglat
|
||||
return lng, lat
|
||||
|
||||
|
||||
def BD09_to_WGS84(self, bd_lng, bd_lat):
|
||||
'''
|
||||
Convert coordinates from BD09 to WGS84 coordinate system
|
||||
:param bd_lng: Longitude in BD09 coordinate system
|
||||
:param bd_lat: Latitude in BD09 coordinate system
|
||||
:return: Converted longitude and latitude in WGS84
|
||||
'''
|
||||
lng, lat = self.BD09_to_GCJ02(bd_lng, bd_lat)
|
||||
return self.GCJ02_to_WGS84(lng, lat)
|
||||
|
||||
|
||||
def WGS84_to_BD09(self, lng, lat):
|
||||
'''
|
||||
Convert coordinates from WGS84 to BD09 coordinate system
|
||||
:param lng: Longitude in WGS84 coordinate system
|
||||
:param lat: Latitude in WGS84 coordinate system
|
||||
:return: Converted longitude and latitude in BD09
|
||||
'''
|
||||
lng, lat = self.WGS84_to_GCJ02(lng, lat)
|
||||
return self.GCJ02_to_BD09(lng, lat)
|
||||
|
||||
|
||||
def _transformlat(self, lng, lat):
|
||||
ret = -100.0 + 2.0 * lng + 3.0 * lat + 0.2 * lat * lat + \
|
||||
0.1 * lng * lat + 0.2 * np.sqrt(np.fabs(lng))
|
||||
ret += (20.0 * np.sin(6.0 * lng * self.pi) + 20.0 *
|
||||
np.sin(2.0 * lng * self.pi)) * 2.0 / 3.0
|
||||
ret += (20.0 * np.sin(lat * self.pi) + 40.0 *
|
||||
np.sin(lat / 3.0 * self.pi)) * 2.0 / 3.0
|
||||
ret += (160.0 * np.sin(lat / 12.0 * self.pi) + 320 *
|
||||
np.sin(lat * self.pi / 30.0)) * 2.0 / 3.0
|
||||
return ret
|
||||
|
||||
|
||||
def _transformlng(self, lng, lat):
|
||||
ret = 300.0 + lng + 2.0 * lat + 0.1 * lng * lng + \
|
||||
0.1 * lng * lat + 0.1 * np.sqrt(np.fabs(lng))
|
||||
ret += (20.0 * np.sin(6.0 * lng * self.pi) + 20.0 *
|
||||
np.sin(2.0 * lng * self.pi)) * 2.0 / 3.0
|
||||
ret += (20.0 * np.sin(lng * self.pi) + 40.0 *
|
||||
np.sin(lng / 3.0 * self.pi)) * 2.0 / 3.0
|
||||
ret += (150.0 * np.sin(lng / 12.0 * self.pi) + 300.0 *
|
||||
np.sin(lng / 30.0 * self.pi)) * 2.0 / 3.0
|
||||
return ret
|
||||
|
||||
def WGS84_to_WebMercator(self, lng, lat):
|
||||
'''
|
||||
Convert coordinates from WGS84 to Web Mercator
|
||||
:param lng: Longitude in WGS84
|
||||
:param lat: Latitude in WGS84
|
||||
:return: Converted Web Mercator coordinates
|
||||
'''
|
||||
x = lng * 20037508.342789 / 180
|
||||
y = np.log(np.tan((90 + lat) * self.pi / 360)) / (self.pi / 180)
|
||||
y = y * 20037508.342789 / 180  # same constant as the x conversion above
|
||||
return x, y
|
||||
|
||||
def WebMercator_to_WGS84(self, x, y):
|
||||
'''
|
||||
Convert coordinates from Web Mercator to WGS84
|
||||
:param x: Web Mercator x coordinate
|
||||
:param y: Web Mercator y coordinate
|
||||
:return: Converted longitude and latitude in WGS84
|
||||
'''
|
||||
lng = x / 20037508.34 * 180
|
||||
lat = y / 20037508.34 * 180
|
||||
lat = 180 / self.pi * (2 * np.arctan(np.exp(lat * self.pi / 180)) - self.pi / 2)
|
||||
return lng, lat
|
||||
|
||||
|
||||
transfer = LngLatTransfer()
|
||||
def get_tile_array(x, y, z, method='nearest', func_source=None, radius=2, fill_value=0, use_gc02=True):
|
||||
"""Resample source image data to map tile
|
||||
|
||||
Args:
|
||||
x, y, z: Tile coordinates
|
||||
method: Resampling method ('nearest' or 'bilinear')
|
||||
func_source: Function to get source image data
|
||||
radius: Search radius in pixels
|
||||
fill_value: Value for no data areas
|
||||
use_gc02: Whether the tile coordinates are in the GCJ02 system (True) or WGS84 (False)
|
||||
|
||||
Returns:
|
||||
ndarray: Resampled tile data
|
||||
"""
|
||||
    bounds = mercantile.bounds(x, y, z)

    if use_gc02:
        # Convert coordinates from GCJ02 to WGS84
        wgs84_lngs, wgs84_lats = transfer.GCJ02_to_WGS84(
            gcj_lng=np.array([bounds.west, bounds.west, bounds.east, bounds.east]),
            gcj_lat=np.array([bounds.north, bounds.south, bounds.south, bounds.north])
        )
        boundary = list(zip(wgs84_lngs, wgs84_lats))
    else:
        boundary = list(zip(
            [bounds.west, bounds.west, bounds.east, bounds.east],
            [bounds.north, bounds.south, bounds.south, bounds.north]
        ))

    source_data = func_source(boundary)

    if source_data is None:
        return None

    arr_image, arr_lngs, arr_lats = source_data

    if use_gc02:
        gcj02_lngs, gcj02_lats = transfer.WGS84_to_GCJ02(arr_lngs, arr_lats)
    else:
        gcj02_lngs, gcj02_lats = arr_lngs, arr_lats

    # Define source and target geometries
    source_def = geometry.SwathDefinition(lons=gcj02_lngs, lats=gcj02_lats)

    xy_bounds = mercantile.xy_bounds(x, y, z)
    target_def = geometry.AreaDefinition(
        'tile', 'tile', 'tile',
        'EPSG:3857',
        TILE_SIZE, TILE_SIZE,
        (xy_bounds.left, xy_bounds.bottom, xy_bounds.right, xy_bounds.top)
    )

    # Resample
    pixel_size = mercantile.CE / 2 ** z / TILE_SIZE
    if method == 'nearest':
        result = kd_tree.resample_nearest(
            source_def, arr_image, target_def,
            radius_of_influence=radius * pixel_size,
            fill_value=fill_value
        )
    elif method == 'bilinear':
        resampler = bilinear.NumpyBilinearResampler(
            source_def, target_def,
            radius_of_influence=radius * pixel_size,
            neighbours=8
        )
        result = resampler.resample(arr_image).astype(arr_image.dtype)
    else:
        raise ValueError(f'Unknown resampling method: {method}')

    return result

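# Usage sketch (illustrative only; the point below is arbitrary): round-trips a
# WGS84 point through Web Mercator with the converter defined above. get_tile_array
# additionally needs a func_source callable returning (image, lons, lats) for a
# boundary, so it is not exercised here.
if __name__ == '__main__':
    demo_x, demo_y = transfer.WGS84_to_WebMercator(116.39, 39.91)
    demo_lng, demo_lat = transfer.WebMercator_to_WGS84(demo_x, demo_y)
    print(f'WebMercator ({demo_x:.2f}, {demo_y:.2f}) -> WGS84 ({demo_lng:.6f}, {demo_lat:.6f})')
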
133
tools/pretraining_data_builder/rsi_process/utils_s1.py
Normal file
@@ -0,0 +1,133 @@
import xml.dom.minidom
import os
from glob import glob
import zipfile
from shapely import wkt
import geopandas as gpd
from osgeo import gdal
import imageio.v2 as iio


def parse_metadata(meta_xml_file):
    """Parse Sentinel-1 metadata XML file

    Args:
        meta_xml_file: Metadata XML file path

    Returns:
        dict: Dictionary containing key metadata information
    """
    record = {}

    # Get sensing start time
    dom = xml.dom.minidom.parse(meta_xml_file)
    sensing_start = dom.getElementsByTagName('startTime')[0].firstChild.data

    product_uri = meta_xml_file.name.split('/')[0]

    record.update({
        'product_uri': product_uri,
        'sensing_start': sensing_start,
    })

    return record


def convert_footprint_to_wkt(footprint):
    """Convert footprint string to WKT format"""
    coords = footprint.strip().split(' ')
    wkt_coords = []
    for coord in coords:
        lat, lon = coord.split(',')
        wkt_coords.append(f"{lon} {lat}")
    return f"MULTIPOLYGON ((({','.join(wkt_coords)})))"

def zip2rec(fn_zip):
    id_img = os.path.splitext(os.path.basename(fn_zip))[0]
    archive = zipfile.ZipFile(fn_zip, 'r')
    xml_files = [f for f in archive.namelist() if f.endswith('-001.xml')]
    if not xml_files:
        raise FileNotFoundError(f"No XML file ending with '-001.xml' found in {fn_zip}")
    fn_xml = archive.open(xml_files[0])
    rec = parse_metadata(fn_xml)
    # rec['geometry'] = wkt.loads(rec['geom_wkt'])
    thumb = archive.open(os.path.join(f'{id_img}.SAFE', 'preview', 'quick-look.png'))
    thumb = iio.imread(thumb)
    rec['thumb'] = thumb
    return rec


def build_catalog(path, fn='catalog'):
    '''
    fn: filename or None
    '''
    list_fnames = glob(os.path.join(path, 'S1*.zip'))

    list_rec = []
    for fn_zip in list_fnames:
        rec = zip2rec(fn_zip)
        list_rec.append(rec)

    # 'geom_wkt' is only present when a footprint has been parsed, so ignore it if missing
    gdf = gpd.GeoDataFrame(list_rec, crs='EPSG:4326').drop(columns='geom_wkt', errors='ignore')
    if fn is not None:
        fn_geojson = os.path.join(path, f"{fn}.geojson")
        gdf.to_file(fn_geojson, driver='GeoJSON')
        return fn_geojson
    else:
        return gdf

def make_full_name(rec, band):
    dict_bands = {
        'VV': '001',
        'VH': '002',
    }
    parts = rec['product_uri'].split('_')

    satellite = parts[0].lower()          # S1A -> s1a
    mode = parts[1].lower()               # IW -> iw
    product_type = parts[2][:3].lower()   # GRDH -> grd
    polarization = band.lower()           # Polarization band, e.g. VV or VH
    start_time = parts[4].lower()         # Start time
    end_time = parts[5].lower()           # End time
    id1 = parts[6].lower()                # e.g. 058175
    id2 = parts[7].lower()                # e.g. 072FF2
    fixed_part = dict_bands[band]         # Band index: 001 for VV, 002 for VH

    # Concatenate to the target format
    file_name = f"{satellite}-{mode}-{product_type}-{polarization}-{start_time}-{end_time}-{id1}-{id2}-{fixed_part}.tiff"

    fn_template = os.path.join(
        rec['product_uri'], 'measurement', file_name
    )
    return fn_template


def warp(
    ds, outputBounds,
    outputBoundsSRS='EPSG:4326',
    xRes=10, yRes=10, targetAlignedPixels=True,
    **kwargs,
):
    options_warp = gdal.WarpOptions(
        format="MEM",
        outputBounds=outputBounds,
        outputBoundsSRS=outputBoundsSRS,
        xRes=xRes, yRes=yRes, targetAlignedPixels=targetAlignedPixels,
        **kwargs,
    )
    ds_warp = gdal.Warp('', ds, options=options_warp)
    return ds_warp

def get_ndarray(
    ds, outputBounds,
    outputBoundsSRS='EPSG:4326',
    xRes=10, yRes=10, targetAlignedPixels=True,
    **kwargs,
):
    ds_warp = warp(
        ds, outputBounds,
        outputBoundsSRS=outputBoundsSRS,
        xRes=xRes, yRes=yRes, targetAlignedPixels=targetAlignedPixels,
        **kwargs
    )
    arr = ds_warp.ReadAsArray()
    ds_warp = None
    return arr

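# Usage sketch (illustrative only): the product name below is hypothetical; it just
# shows the in-archive measurement path that make_full_name assembles for the VV band.
if __name__ == '__main__':
    demo_rec = {
        'product_uri': 'S1A_IW_GRDH_1SDV_20240315T101010_20240315T101035_058175_072FF2_1A2B.SAFE',
    }
    print(make_full_name(demo_rec, 'VV'))
    # S1A_..._1A2B.SAFE/measurement/
    #   s1a-iw-grd-vv-20240315t101010-20240315t101035-058175-072ff2-001.tiff
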
158
tools/pretraining_data_builder/rsi_process/utils_s2.py
Normal file
@@ -0,0 +1,158 @@
import xml.dom.minidom
import os
from glob import glob
import zipfile
from shapely import wkt
import geopandas as gpd
from osgeo import gdal
import imageio.v2 as iio


def parse_metadata(meta_xml_file):
    """Parse Sentinel-2 metadata XML file

    Args:
        meta_xml_file: Path to metadata XML file

    Returns:
        dict: Metadata information including sensing time, product URI, etc.
    """
    record = {}
    try:
        dom = xml.dom.minidom.parse(meta_xml_file)

        # Get sensing start time
        sensing_start = dom.getElementsByTagName('DATATAKE_SENSING_START')[0].firstChild.data[0:10]

        # Get product URI and image paths
        product_uri = dom.getElementsByTagName('PRODUCT_URI')[0].firstChild.data

        image_file = dom.getElementsByTagName('IMAGE_FILE')[0].firstChild.data
        items = image_file.split('/')
        granule_path = items[1]
        img_name = items[4].split('_')[0] + '_' + items[4].split('_')[1]

        # Get footprint
        footprint = dom.getElementsByTagName('EXT_POS_LIST')[0].firstChild.data
        geom_wkt = convert_footprint_to_wkt(footprint)

        # Get cloud coverage info
        cloud_coverage = float(dom.getElementsByTagName('Cloud_Coverage_Assessment')[0].firstChild.data)
        cloud_shadow = float(dom.getElementsByTagName('CLOUD_SHADOW_PERCENTAGE')[0].firstChild.data)
        medium_clouds = float(dom.getElementsByTagName('MEDIUM_PROBA_CLOUDS_PERCENTAGE')[0].firstChild.data)
        high_clouds = float(dom.getElementsByTagName('HIGH_PROBA_CLOUDS_PERCENTAGE')[0].firstChild.data)

        record.update({
            'product_uri': product_uri,
            'sensing_start': sensing_start,
            'granule_path': granule_path,
            'img_name': img_name,
            'cloud_cover': cloud_coverage,
            'cloud_shadow': cloud_shadow,
            'medium_clouds': medium_clouds,
            'high_clouds': high_clouds,
            'geom_wkt': geom_wkt
        })

    except Exception as e:
        print(f'Failed to parse XML: {e}')

    return record


def convert_footprint_to_wkt(footprint):
    """Convert footprint string to WKT format"""
    coords = footprint.strip().split(' ')
    wkt_coords = []
    for i in range(0, len(coords), 2):
        wkt_coords.append(f"{coords[i+1]} {coords[i]}")
    return f"MULTIPOLYGON ((({','.join(wkt_coords)})))"


def zip2rec(fn_zip):
    id_img = os.path.splitext(os.path.basename(fn_zip))[0]
    archive = zipfile.ZipFile(fn_zip, 'r')
    fn_xml = archive.open(os.path.join(f'{id_img}.SAFE', 'MTD_MSIL2A.xml'))
    rec = parse_metadata(fn_xml)
    rec['geometry'] = wkt.loads(rec['geom_wkt'])
    thumb = archive.open(os.path.join(f'{id_img}.SAFE', f'{id_img}-ql.jpg'))
    thumb = iio.imread(thumb)
    rec['thumb'] = thumb
    return rec


def build_catalog(path, fn='catalog'):
    '''
    fn: filename or None
    '''
    list_fnames = glob(os.path.join(path, 'S2*.zip'))

    list_rec = []
    for fn_zip in list_fnames:
        rec = zip2rec(fn_zip)
        list_rec.append(rec)

    gdf = gpd.GeoDataFrame(list_rec, crs='EPSG:4326').drop(columns='geom_wkt')
    if fn is not None:
        fn_geojson = os.path.join(path, f"{fn}.geojson")
        gdf.to_file(fn_geojson, driver='GeoJSON')
        return fn_geojson
    else:
        return gdf


def make_full_name(rec, band):
    dict_bands = {
        'B2': ['B02', '10m'],
        'B3': ['B03', '10m'],
        'B4': ['B04', '10m'],
        'B8': ['B08', '10m'],
        'B5': ['B05', '20m'],
        'B6': ['B06', '20m'],
        'B7': ['B07', '20m'],
        'B8A': ['B8A', '20m'],
        'B11': ['B11', '20m'],
        'B12': ['B12', '20m'],
        'SCL': ['SCL', '20m'],
    }
    fn_template = os.path.join(
        '{p0}', 'GRANULE',
        '{p1}', 'IMG_DATA', "R{p2}",
        '{p3}_{p4}_{p2}.jp2'
    )
    return fn_template.format(**{
        'p0': rec['product_uri'],
        'p0b': rec['product_uri'].split('.')[0],
        'p1': rec['granule_path'],
        'p2': dict_bands[band][1],
        'p3': rec['img_name'],
        'p4': dict_bands[band][0],
    })


def warp(
    ds, outputBounds,
    outputBoundsSRS='EPSG:4326',
    xRes=10, yRes=10, targetAlignedPixels=True,
    **kwargs,
):
    options_warp = gdal.WarpOptions(
        format="MEM",
        outputBounds=outputBounds,
        outputBoundsSRS=outputBoundsSRS,
        xRes=xRes, yRes=yRes, targetAlignedPixels=targetAlignedPixels,
        **kwargs,
    )
    ds_warp = gdal.Warp('', ds, options=options_warp)
    return ds_warp

def get_ndarray(
    ds, outputBounds,
    outputBoundsSRS='EPSG:4326',
    xRes=10, yRes=10, targetAlignedPixels=True,
    **kwargs,
):
    ds_warp = warp(
        ds, outputBounds,
        outputBoundsSRS=outputBoundsSRS,
        xRes=xRes, yRes=yRes, targetAlignedPixels=targetAlignedPixels,
        **kwargs
    )
    arr = ds_warp.ReadAsArray()
    ds_warp = None
    return arr

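# Usage sketch (illustrative only): the record values below are hypothetical; they
# just show the in-archive .jp2 path that make_full_name assembles for band B4.
if __name__ == '__main__':
    demo_rec = {
        'product_uri': 'S2A_MSIL2A_20240501T103021_N0510_R108_T32TQM_20240501T140000.SAFE',
        'granule_path': 'L2A_T32TQM_A046123_20240501T103025',
        'img_name': 'T32TQM_20240501T103021',
    }
    print(make_full_name(demo_rec, 'B4'))
    # S2A_..._20240501T140000.SAFE/GRANULE/L2A_T32TQM_A046123_20240501T103025/
    #   IMG_DATA/R10m/T32TQM_20240501T103021_B04_10m.jp2
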
237
tools/pretraining_data_builder/rsi_process/utils_wv.py
Normal file
@@ -0,0 +1,237 @@
import os
from osgeo import gdal
import numpy as np
from datetime import datetime
import xml.etree.ElementTree as ET


def parse_metadata(meta_xml_file):
    """Parse the WorldView metadata XML file

    Args:
        meta_xml_file: Metadata XML file path

    Returns:
        dict: Dictionary containing key metadata information
    """
    record = {}

    try:
        tree = ET.parse(meta_xml_file)
        root = tree.getroot()

        ns = {'imd': root.tag.split('}')[0].strip('{')}

        # Get basic information
        record['satellite_id'] = root.find('.//imd:satelliteID', ns).text
        record['product_type'] = root.find('.//imd:productType', ns).text

        # Get acquisition time
        acq_time = root.find('.//imd:firstLineTime', ns).text
        record['sensing_start'] = datetime.strptime(acq_time, '%Y-%m-%dT%H:%M:%S.%fZ')

        # Get solar angles
        record['sun_azimuth'] = float(root.find('.//imd:meanSunAz', ns).text)
        record['sun_elevation'] = float(root.find('.//imd:meanSunEl', ns).text)

        # Get satellite angles
        record['satellite_azimuth'] = float(root.find('.//imd:meanSatAz', ns).text)
        record['satellite_elevation'] = float(root.find('.//imd:meanSatEl', ns).text)

        # Get cloud cover
        cloud_cover = root.find('.//imd:cloudCover', ns)
        record['cloud_cover'] = float(cloud_cover.text) if cloud_cover is not None else None

        # Get image extent
        record['ul_lon'] = float(root.find('.//imd:ULLon', ns).text)
        record['ul_lat'] = float(root.find('.//imd:ULLat', ns).text)
        record['ur_lon'] = float(root.find('.//imd:URLon', ns).text)
        record['ur_lat'] = float(root.find('.//imd:URLat', ns).text)
        record['ll_lon'] = float(root.find('.//imd:LLLon', ns).text)
        record['ll_lat'] = float(root.find('.//imd:LLLat', ns).text)
        record['lr_lon'] = float(root.find('.//imd:LRLon', ns).text)
        record['lr_lat'] = float(root.find('.//imd:LRLat', ns).text)

        # Build WKT format geometry information
        record['geom_wkt'] = create_footprint_wkt(record)

    except Exception as e:
        print(f"Error parsing metadata: {str(e)}")
        return None

    return record


def create_footprint_wkt(record):
    """Create a WKT format polygon based on corner coordinates

    Args:
        record: Dictionary containing corner coordinates

    Returns:
        str: WKT format polygon string
    """
    coords = [
        (record['ul_lon'], record['ul_lat']),
        (record['ur_lon'], record['ur_lat']),
        (record['lr_lon'], record['lr_lat']),
        (record['ll_lon'], record['ll_lat']),
        (record['ul_lon'], record['ul_lat'])
    ]

    coord_str = ', '.join([f"{lon} {lat}" for lon, lat in coords])
    return f"POLYGON(({coord_str}))"

def get_band_info(ds):
    """Get the band information of the image

    Args:
        ds: GDAL dataset

    Returns:
        list: Band information list
    """
    bands = []
    for i in range(1, ds.RasterCount + 1):
        band = ds.GetRasterBand(i)
        band_info = {
            'band_number': i,
            'data_type': gdal.GetDataTypeName(band.DataType),
            'nodata_value': band.GetNoDataValue()
        }
        bands.append(band_info)
    return bands


def read_as_array(ds, window=None):
    """Read image data as a numpy array

    Args:
        ds: GDAL dataset
        window: Read window, format as (xoff, yoff, xsize, ysize)

    Returns:
        numpy.ndarray: Image data array
    """
    if window is None:
        return ds.ReadAsArray()
    else:
        xoff, yoff, xsize, ysize = window
        return ds.ReadAsArray(xoff, yoff, xsize, ysize)


def get_image_info(fn_img):
    """Get basic information of a WorldView image

    Args:
        fn_img: Image file path

    Returns:
        dict: Image information dictionary
    """
    ds = gdal.Open(fn_img)
    if ds is None:
        raise Exception(f"Cannot open {fn_img}")

    info = {
        'width': ds.RasterXSize,
        'height': ds.RasterYSize,
        'bands': ds.RasterCount,
        'projection': ds.GetProjection(),
        'geotransform': ds.GetGeoTransform(),
        'band_info': get_band_info(ds)
    }

    xml_file = fn_img.replace('.tif', '.xml')
    if os.path.exists(xml_file):
        metadata = parse_metadata(xml_file)
        if metadata:
            info.update(metadata)

    ds = None
    return info


def calculate_stats(fn_img, percentiles=[2, 98]):
    """Calculate the statistics of the image

    Args:
        fn_img: Image file path
        percentiles: List of percentiles

    Returns:
        dict: Statistics dictionary
    """
    ds = gdal.Open(fn_img)
    stats = {}

    for i in range(1, ds.RasterCount + 1):
        band = ds.GetRasterBand(i)
        array = band.ReadAsArray()
        valid_data = array[array != band.GetNoDataValue()]

        stats[f'band_{i}'] = {
            'min': np.min(valid_data),
            'max': np.max(valid_data),
            'mean': np.mean(valid_data),
            'std': np.std(valid_data),
            'percentiles': {
                p: np.percentile(valid_data, p)
                for p in percentiles
            }
        }

    ds = None
    return stats


def create_quicklook(fn_img, output_file, size=(1024, 1024)):
    """Create a thumbnail

    Args:
        fn_img: Image file path
        output_file: Output file path
        size: Output image size
    """
    ds = gdal.Open(fn_img)

    if ds.RasterCount >= 3:
        r = ds.GetRasterBand(1).ReadAsArray()
        g = ds.GetRasterBand(2).ReadAsArray()
        b = ds.GetRasterBand(3).ReadAsArray()

        def stretch(arr):
            p2, p98 = np.percentile(arr[arr > 0], (2, 98))
            return np.clip((arr - p2) / (p98 - p2) * 255, 0, 255).astype(np.uint8)

        rgb = np.dstack([stretch(r), stretch(g), stretch(b)])

        from PIL import Image
        img = Image.fromarray(rgb)
        img.thumbnail(size)
        img.save(output_file)

    ds = None


def warp(ds, outputBounds,
         outputBoundsSRS='EPSG:4326',
         xRes=2, yRes=2,
         targetAlignedPixels=True,
         **kwargs):
    """Reprojection and resampling

    Args:
        ds: GDAL dataset
        outputBounds: Output range
        outputBoundsSRS: Output coordinate system
        xRes, yRes: Output resolution
        targetAlignedPixels: Whether to align pixels
        **kwargs: Other gdal.Warp parameters

    Returns:
        GDAL dataset
    """
    options_warp = gdal.WarpOptions(
        format="MEM",
        outputBounds=outputBounds,
        outputBoundsSRS=outputBoundsSRS,
        xRes=xRes, yRes=yRes,
        targetAlignedPixels=targetAlignedPixels,
        **kwargs
    )
    ds_warp = gdal.Warp('', ds, options=options_warp)
    return ds_warp

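# Usage sketch (illustrative only): the corner coordinates below are made up; they
# just show the WKT polygon that create_footprint_wkt builds from a parsed record.
if __name__ == '__main__':
    demo_rec = {
        'ul_lon': 116.20, 'ul_lat': 40.05,
        'ur_lon': 116.45, 'ur_lat': 40.05,
        'lr_lon': 116.45, 'lr_lat': 39.85,
        'll_lon': 116.20, 'll_lat': 39.85,
    }
    print(create_footprint_wkt(demo_rec))
    # POLYGON((116.2 40.05, 116.45 40.05, 116.45 39.85, 116.2 39.85, 116.2 40.05))
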
11
tools/pretraining_data_builder/run_data_builder.sh
Normal file
@@ -0,0 +1,11 @@
#!/bin/bash
source activate data_builder
export USERNAME=your_username
export PASSWORD=your_password
export API_KEY=your_api_key

export PYTHONPATH=$PYTHONPATH:$(pwd)

LMDB_PATH=your_lmdb_path

python rsi_pipeline/data_builder.py $LMDB_PATH