Skip to content
Open

Mps #11

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,15 @@ work_dirs

outputs
experiments
renders
renders

.vscode
thirdparty/gsplat/
thirdparty/gsplat-mps/**/*.bak
thirdparty/gsplat-mps/.clang-format
thirdparty/gsplat-mps/.clangd_template
thirdparty/gsplat-mps/.github/
thirdparty/gsplat-mps/.gitignore
thirdparty/gsplat-mps/.gitmodules
thirdparty/gsplat-mps/docs/
thirdparty/gsplat-mps/examples/
23 changes: 23 additions & 0 deletions 1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pickle
from collections import Counter
from pathlib import Path

# Quick inspection script: count how many videos exist per travel id and how
# many frames each vehicle track appears in, for one road block.

# Pickled scene dictionary to inspect (relative to the working directory).
pkl = Path("data/main_mt/MTGS/road_block-331220_4690660_331190_4690710/video_scene_dict.pkl")

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this script at pickles you produced or otherwise trust.
# The context manager guarantees the file handle is closed after loading.
with open(pkl, "rb") as fh:
    data = pickle.load(fh)

vehicle_ids = Counter()    # track token -> number of frames it appears in
travel_counts = Counter()  # travel id   -> number of videos with that id

for video_token, v in data.items():
    # The travel id is the integer after the last "-" of the video token.
    travel_id = int(video_token.split("-")[-1])
    travel_counts[travel_id] += 1
    for f in v["frame_infos"]:
        for name, token in zip(f["gt_names"], f["track_tokens"]):
            if name == "vehicle":
                vehicle_ids[token] += 1

print("travel_ids:", sorted(travel_counts.items())[:20], "total:", len(travel_counts))
print("vehicles total:", len(vehicle_ids))
print("top vehicles:")
for token, count in vehicle_ids.most_common(20):
    print(token, count)
145 changes: 145 additions & 0 deletions 2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import pickle
from collections import Counter, defaultdict
from pathlib import Path
import math

import numpy as np
from pyquaternion import Quaternion

# Pickled scene dictionary to analyze (relative to the working directory).
PKL_PATH = Path("data/MTGS/road_block-331220_4690660_331190_4690710/video_scene_dict.pkl")
# Index of the first frame of each video that is included in the analysis.
FRAME_START = 0
# Index of the last frame to include (inclusive, clamped to the video length).
FRAME_END = None # None = run through the last frame
# A vehicle whose displacement has cosine similarity >= this value with the ego
# displacement is labelled "same"; <= the negated value is "opposite"; anything
# in between is "cross". 0.7 corresponds to an angle of roughly 45 degrees.
DIRECTION_COS_THRESHOLD = 0.7


def classify_direction(ego_disp, veh_disp, threshold):
    """Compare a vehicle's displacement direction with the ego displacement.

    Args:
        ego_disp: Displacement vector of the ego vehicle.
        veh_disp: Displacement vector of the other vehicle.
        threshold: Cosine-similarity cut-off separating "same"/"opposite"
            from "cross".

    Returns:
        A ``(cos_sim, angle_degrees, label)`` tuple. ``label`` is ``"same"``
        when ``cos_sim >= threshold``, ``"opposite"`` when
        ``cos_sim <= -threshold`` and ``"cross"`` otherwise. If either vector
        is shorter than ``1e-6`` the direction is undefined and
        ``(0.0, 0.0, "unknown")`` is returned.
    """
    lengths = (np.linalg.norm(ego_disp), np.linalg.norm(veh_disp))
    # A (near-)zero displacement has no meaningful direction.
    if any(length < 1e-6 for length in lengths):
        return 0.0, 0.0, "unknown"

    raw_cos = float(np.dot(ego_disp, veh_disp) / (lengths[0] * lengths[1]))
    # Clamp so rounding error cannot push the value outside arccos' domain.
    cos_sim = max(min(raw_cos, 1.0), -1.0)
    angle = float(np.degrees(np.arccos(cos_sim)))

    label = (
        "same" if cos_sim >= threshold
        else "opposite" if cos_sim <= -threshold
        else "cross"
    )
    return cos_sim, angle, label


def _print_track_rows(rows):
    """Print one line per ``(token, disp, frames, cos, angle, label)`` row."""
    for token, disp, count, cos_sim, angle, label in rows:
        print(
            token,
            "disp",
            round(disp, 2),
            "frames",
            count,
            "cos",
            round(cos_sim, 2),
            "angle",
            round(angle, 1),
            label,
        )


def analyze_video(video_token, frames):
    """Print vehicle statistics for the frames of a single video.

    For the frame range ``[FRAME_START, FRAME_END]`` (inclusive, clamped to
    the video length) this reports:

    * which vehicle tracks were most often the nearest vehicle with a
      positive ego-frame x coordinate, with their average/minimum distance;
    * the ten tracks with the largest first-to-last displacement ("moving",
      displacement >= 3.0) and their direction relative to the ego
      displacement;
    * the ten tracks with the smallest displacement ("static").

    Args:
        video_token: Video identifier; the integer after its last ``"-"`` is
            printed as the travel id.
        frames: Sequence of frame-info dicts. Each is read for the keys
            ``ego2global_translation``, ``ego2global_rotation``,
            ``gt_names``, ``track_tokens`` and ``gt_boxes``.

    Returns:
        None. Results are written to stdout; nothing is printed when
        ``frames`` is empty.
    """
    if not frames:
        return
    if FRAME_END is None:
        end_idx = len(frames) - 1
    else:
        end_idx = min(FRAME_END, len(frames) - 1)

    # Per-track distance statistics over the frames where it passed the x > 0 filter.
    stats = defaultdict(lambda: {"min": 1e9, "sum": 0.0, "count": 0, "min_frame": None})
    nearest_counts = Counter()           # track token -> frames in which it was nearest
    track_positions = defaultdict(list)  # track token -> global box centers, in frame order
    ego_positions = []

    for idx, f in enumerate(frames[FRAME_START:end_idx + 1], start=FRAME_START):
        best = None
        # float64 on purpose: float32 keeps only ~7 significant digits, which is
        # sub-metre error if the translations are map-scale coordinates
        # (presumably they are, judging by the road-block name -- TODO confirm),
        # and that would distort the 3 m moving/static split below.
        e2g_trans = np.array(f["ego2global_translation"], dtype=np.float64)
        e2g_rot = Quaternion(f["ego2global_rotation"]).rotation_matrix
        ego_positions.append(e2g_trans)
        for name, token, box in zip(f["gt_names"], f["track_tokens"], f["gt_boxes"]):
            if name != "vehicle":
                continue
            # The first three box entries are used as the center in the ego frame.
            center_ego = np.array(box[:3], dtype=np.float64)
            center_global = center_ego @ e2g_rot.T + e2g_trans
            track_positions[token].append(center_global)
            x, y, z = box[:3]
            if x <= 0:  # only consider vehicles in front of the ego; drop this filter if not needed
                continue
            d = math.sqrt(x * x + y * y + z * z)
            s = stats[token]
            s["sum"] += d
            s["count"] += 1
            if d < s["min"]:
                s["min"] = d
                s["min_frame"] = idx
            if best is None or d < best[0]:
                best = (d, token)
        if best is not None:
            nearest_counts[best[1]] += 1

    # Ego displacement between the first and last analyzed frame.
    ego_disp = np.zeros(3, dtype=np.float64)
    if len(ego_positions) >= 2:
        ego_disp = ego_positions[-1] - ego_positions[0]

    moving = []
    static = []
    moving_direction_counts = Counter()
    for token, centers in track_positions.items():
        if len(centers) < 2:
            # A single observation gives no displacement.
            continue
        veh_disp = centers[-1] - centers[0]
        disp = float(np.linalg.norm(veh_disp))
        cos_sim, angle, label = classify_direction(ego_disp, veh_disp, DIRECTION_COS_THRESHOLD)
        row = (token, disp, len(centers), cos_sim, angle, label)
        if disp < 3.0:
            static.append(row)
        else:
            moving.append(row)
            moving_direction_counts[label] += 1

    travel_id = int(video_token.split("-")[-1])
    print(f"\ntravel_id={travel_id} video_token={video_token} frames={end_idx - FRAME_START + 1}")

    print("nearest_counts top:")
    for token, cnt in nearest_counts.most_common(10):
        s = stats[token]
        avg = s["sum"] / s["count"]
        print(
            token,
            "nearest_frames",
            cnt,
            "avg_dist",
            round(avg, 2),
            "min_dist",
            round(s["min"], 2),
            "min_frame",
            s["min_frame"],
        )

    print("moving top:")
    _print_track_rows(sorted(moving, key=lambda x: -x[1])[:10])

    print("moving direction counts:", dict(moving_direction_counts))

    print("static top:")
    _print_track_rows(sorted(static, key=lambda x: x[1])[:10])


# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only run this against pickles you produced or otherwise trust.
# The context manager guarantees the file handle is closed after loading.
with open(PKL_PATH, "rb") as fh:
    data = pickle.load(fh)

# Process videos in ascending travel-id order (the integer after the last "-").
video_tokens = sorted(data, key=lambda vt: int(vt.split("-")[-1]))
for video_token in video_tokens:
    analyze_video(video_token, data[video_token]["frame_infos"])
9 changes: 3 additions & 6 deletions mtgs/dataset/nuplan_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,7 @@ def _generate_dataparser_outputs(self, split="train"):
pose = ego2global @ cam2ego

image_filenames.append(Path(
os.path.join(
video_scene.raw_image_path,
cam_info['data_path']
)
video_scene.runtime_image_path(cam_info['data_path'])
))

if self.config.undistort_images == "optimal":
Expand Down Expand Up @@ -352,7 +349,7 @@ def _generate_dataparser_outputs(self, split="train"):

lidar2cams.append(lidar2cam) # opencv camera
lidar_paths.append(
os.path.join(video_scene.raw_lidar_path, info['lidar_path'])
video_scene.runtime_lidar_path(info['lidar_path'])
)

v_adjust_factors.append(
Expand Down Expand Up @@ -397,7 +394,7 @@ def _generate_dataparser_outputs(self, split="train"):
poses[:, :3, 3] *= self.config.scale_factor

if self.config.load_cam_optim_from is not None:
model = torch.load(self.config.load_cam_optim_from, map_location='cpu')
model = torch.load(self.config.load_cam_optim_from, map_location='cpu', weights_only=False)
pose_adj = model['pipeline'][self.config.cam_optim_key]
if pose_adj.shape[0] != poses.shape[0]:
CONSOLE.log(f"[WARNING] pose_adj shape {pose_adj.shape[0]} does not match poses shape {poses.shape[0]}")
Expand Down
173 changes: 173 additions & 0 deletions mtgs/scene_model/gaussian_model/rigid_node_mirrored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#-------------------------------------------------------------------------------#
# MTGS: Multi-Traversal Gaussian Splatting (https://arxiv.org/abs/2503.12552) #
# Source code: https://github.com/OpenDriveLab/MTGS #
# Copyright (c) OpenDriveLab. All rights reserved. #
#-------------------------------------------------------------------------------#
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union, Any

import torch
from torch.nn import Parameter, Module

try:
from gsplat.cuda._wrapper import spherical_harmonics
except ImportError:
print("Please install gsplat>=1.0.0")

from .utils import quat_mult, quat_to_rotmat

from .rigid_node import RigidSubModelConfig, RigidSubModel


def flip_spherical_harmonics(coeff, sh_degree=3):
    """
    Flip the spherical harmonics coefficients along the y-axis.

    Mirroring y -> -y negates exactly the real SH basis functions with
    order m < 0, so the corresponding coefficients are multiplied by -1 and
    all others are left untouched.

    Args:
        coeff (torch.Tensor): A tensor of shape [N, K, 3], where N is the number of Gaussians,
                              K >= (sh_degree + 1) ** 2 is the number of spherical harmonics
                              coefficients (16 for degree 3), and 3 is the feature dimension.
        sh_degree (int): Highest SH degree stored in ``coeff``. Coefficients are assumed to be
                         ordered by degree l and then by order m = -l..l, i.e. at index
                         l * l + l + m.

    Returns:
        torch.Tensor: The flipped spherical harmonics coefficients, with the same shape, dtype
                      and device as ``coeff``. For ``sh_degree == 0`` the input tensor itself
                      is returned, since there is nothing to flip.

    Raises:
        ValueError: If ``sh_degree`` is negative or not an integer value.
    """
    degree = int(sh_degree)
    if degree != sh_degree or degree < 0:
        raise ValueError(f"Unsupported SH degree: {sh_degree}")
    if degree == 0:
        return coeff

    # Within degree l the m < 0 coefficients occupy indices l*l .. l*l + l - 1.
    # For degree 3 this yields [1, 4, 5, 9, 10, 11].
    indices_m_negative = [l * l + k for l in range(1, degree + 1) for k in range(l)]

    # Build the +1/-1 mask with coeff's own dtype and device so the product
    # does not silently promote half-precision inputs to float32.
    flip_factors = coeff.new_ones(coeff.shape[1])
    flip_factors[indices_m_negative] = -1

    # Reshape to [1, K, 1] so the mask broadcasts over Gaussians and channels.
    return coeff * flip_factors.view(1, -1, 1)

@dataclass
class MirroredRigidSubModelConfig(RigidSubModelConfig):
    """Configuration for :class:`MirroredRigidSubModel`.

    Extends ``RigidSubModelConfig`` with a switch that controls whether
    static nodes are mirrored as well.
    """

    # Model class associated with this config; resolved lazily through the
    # lambda because MirroredRigidSubModel is defined further down the module.
    _target: Type = field(default_factory=lambda: MirroredRigidSubModel)
    # When False, a node whose ``is_static`` flag is set skips mirroring and
    # uses the parent RigidSubModel implementation for every overridden method.
    mirror_static: bool = True


class MirroredRigidSubModel(RigidSubModel):
    """Rigid sub-model that emits every Gaussian twice: once as stored and
    once mirrored across the object's local ``y = 0`` plane.

    Every getter returns ``2 * N`` entries for ``N`` stored Gaussians, with
    the originals in the first half and the mirrored copies in the second
    half. When the node is static and ``config.mirror_static`` is False,
    each method defers to the parent ``RigidSubModel`` implementation.
    """

    config: MirroredRigidSubModelConfig

    def _mirroring_disabled(self):
        """Return a truthy value when the parent (non-mirrored) behavior applies."""
        return self.is_static and not self.config.mirror_static

    def _mirrored_sh_coeffs(self, true_features_dc):
        """Stack DC and rest SH coefficients and append their y-mirrored copy.

        Args:
            true_features_dc: DC coefficients, one row per stored Gaussian.

        Returns:
            Tensor of shape ``[2 * N, K, 3]`` where ``K`` is the number of SH
            coefficients per Gaussian; the second half is the mirrored copy.
        """
        colors = torch.cat((true_features_dc[:, None, :], self.features_rest), dim=1)
        colors = colors.unsqueeze(0).repeat(2, 1, 1, 1)
        colors[1, ...] = flip_spherical_harmonics(colors[1, ...], self.sh_degree)
        # K is read from the tensor rather than hard-coded so SH degrees other
        # than 3 (K != 16) work as well.
        return colors.view(-1, colors.shape[-2], 3)

    def get_means(self, quat_cur_frame, trans_cur_frame):
        """Return the global means of the original and mirrored Gaussians.

        Args:
            quat_cur_frame: Object rotation for the current frame (quaternion).
            trans_cur_frame: Object translation for the current frame.

        Returns:
            Tensor of ``2 * N`` global positions.
        """
        if self._mirroring_disabled():
            return super().get_means(quat_cur_frame, trans_cur_frame)

        local_means: torch.Tensor = self.gauss_params['means']
        # Mirror in the object's local frame by negating the y coordinate.
        local_means_flipped = local_means * local_means.new_tensor([1, -1, 1]).view(1, 3)
        local_means = torch.cat([local_means, local_means_flipped], dim=0)

        rot_cur_frame = quat_to_rotmat(quat_cur_frame)
        global_means = local_means @ rot_cur_frame.T + trans_cur_frame
        return global_means

    def get_quats(self, quat_cur_frame, trans_cur_frame):
        """Return the global orientations of the original and mirrored Gaussians.

        Args:
            quat_cur_frame: Object rotation for the current frame (quaternion).
            trans_cur_frame: Object translation; only forwarded to the parent
                implementation when mirroring is disabled.

        Returns:
            Tensor of ``2 * N`` quaternions.
        """
        if self._mirroring_disabled():
            return super().get_quats(quat_cur_frame, trans_cur_frame)

        local_quats = self.quats / self.quats.norm(dim=-1, keepdim=True)
        # Negating the 2nd and 4th components mirrors the rotation across the
        # y = 0 plane -- assumes (w, x, y, z) component order, TODO confirm.
        flip_tensor = local_quats.new_tensor([1, -1, 1, -1]).view(1, 4)
        local_quats_flipped = local_quats * flip_tensor
        local_quats = torch.cat([local_quats, local_quats_flipped], dim=0)
        global_quats = quat_mult(quat_cur_frame, local_quats)

        return global_quats

    def get_scales(self):
        """Return activated (exponentiated) scales, duplicated for the mirrored half."""
        if self._mirroring_disabled():
            return super().get_scales()
        scales = torch.exp(self.scales)
        return torch.cat([scales, scales], dim=0)

    def get_rgbs(self, camera_to_worlds, quat_cur_frame=None, trans_cur_frame=None, timestamp=None, global_current_means=None):
        """Evaluate view-dependent colors for the original and mirrored Gaussians.

        Args:
            camera_to_worlds: Camera-to-world transform; its ``[..., :3, 3]``
                column is used as the camera position.
            quat_cur_frame: Object rotation for the current frame.
            trans_cur_frame: Object translation for the current frame.
            timestamp: Timestamp forwarded to ``get_true_features_dc``.
            global_current_means: Optional precomputed global means for all
                ``2 * N`` Gaussians; computed via ``get_means`` when None.

        Returns:
            Tensor of ``2 * N`` RGB colors in ``[0, 1]``.
        """
        if self._mirroring_disabled():
            return super().get_rgbs(camera_to_worlds, quat_cur_frame, trans_cur_frame, timestamp, global_current_means)
        cam_obj_yaw = self.get_cam_obj_yaw(camera_to_worlds, quat_cur_frame)
        true_features_dc = self.get_true_features_dc(timestamp, cam_obj_yaw)
        colors = self._mirrored_sh_coeffs(true_features_dc)

        if self.sh_degree > 0:
            viewdirs = self.get_means(quat_cur_frame, trans_cur_frame) if global_current_means is None else global_current_means
            # Detached so the color term does not back-propagate into positions.
            viewdirs = viewdirs.detach() - camera_to_worlds[..., :3, 3]  # (N, 3)
            viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
            # The SH degree actually evaluated grows with the training step,
            # capped at the configured maximum.
            n = min(self.step // self.ctrl_config.sh_degree_interval, self.sh_degree)
            rgbs = spherical_harmonics(n, viewdirs, colors)
            rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
        else:
            rgbs = torch.sigmoid(colors[:, 0, :])

        return rgbs

    def get_opacity(self):
        """Return activated (sigmoid) opacities, duplicated for the mirrored half."""
        if self._mirroring_disabled():
            return super().get_opacity()
        return torch.sigmoid(self.gauss_params['opacities']).squeeze(-1).repeat(2)

    def get_gaussian_params(self, travel_id=None, frame_idx=None, timestamp=None, **kwargs):
        """Collect the parameters of all ``2 * N`` Gaussians for one frame.

        Args:
            travel_id: Traversal being queried; must equal ``self.travel_id``.
            frame_idx: Frame index within the traversal, or None.
            timestamp: Timestamp to query, or None to look it up from
                ``frame_idx``.
            **kwargs: Only forwarded to the parent implementation when
                mirroring is disabled.

        Returns:
            None if the traversal does not match, if neither ``frame_idx`` nor
            ``timestamp`` is given, or if no object pose is available.
            Otherwise a dict with keys ``means``, ``scales``, ``quats``,
            ``features_dc``, ``features_rest`` and ``opacities``. ``scales``
            and ``opacities`` are the stored parameters repeated twice; no
            exp/sigmoid activation is applied to them here.
        """
        if self._mirroring_disabled():
            return super().get_gaussian_params(travel_id, frame_idx, timestamp, **kwargs)
        if travel_id != self.travel_id or (frame_idx is None and timestamp is None):
            return None

        if frame_idx is not None:
            assert frame_idx < self.num_frames

        quat_cur_frame, trans_cur_frame = self.get_object_pose(frame_idx, timestamp)
        if quat_cur_frame is None or trans_cur_frame is None:
            return None

        if timestamp is None:
            timestamp = self.dataframe_dict["frame_timestamps"][frame_idx]

        true_features_dc = self.get_true_features_dc(timestamp)
        colors = self._mirrored_sh_coeffs(true_features_dc)

        return {
            "means": self.get_means(quat_cur_frame, trans_cur_frame),
            "scales": self.scales.repeat(2, 1),
            "quats": self.get_quats(quat_cur_frame, trans_cur_frame),
            "features_dc": colors[:, 0, :],
            "features_rest": colors[:, 1:, :],
            "opacities": self.opacities.repeat(2, 1),
        }

    def update_statistics(self, xys_grad: torch.Tensor, radii: torch.Tensor):
        """Fold statistics of the ``2 * N`` rendered Gaussians back onto the ``N`` stored ones.

        Args:
            xys_grad: Per-Gaussian statistic with ``2 * N`` entries, or None.
            radii: Per-Gaussian radii with ``2 * N`` entries, or None.

        For each stored Gaussian the larger value of its original and
        mirrored copy is kept. If either argument is None both stored
        statistics are reset to None.
        """
        if self._mirroring_disabled():
            return super().update_statistics(xys_grad, radii)

        if xys_grad is None or radii is None:
            self.xys_grad = None
            self.radii = None
            return

        N = xys_grad.shape[0] // 2
        assert N == self.num_points
        # Row 0 holds the originals, row 1 the mirrored copies.
        xys_grad = xys_grad.view(2, N).max(dim=0).values
        radii = radii.view(2, N).max(dim=0).values

        self.xys_grad = xys_grad
        self.radii = radii
Loading