diff --git a/MantleAPI/include/MantleAPI/Traffic/entity_helper.h b/MantleAPI/include/MantleAPI/Traffic/entity_helper.h
index a09a4534b122807c65fb6dda7f586fb981a05d1b..e333f386481c59deeadbe099083ff257c0484a21 100644
--- a/MantleAPI/include/MantleAPI/Traffic/entity_helper.h
+++ b/MantleAPI/include/MantleAPI/Traffic/entity_helper.h
@@ -21,17 +21,39 @@
 
 namespace mantle_api
 {
-
-inline void SetSpeed(mantle_api::IEntity* entity, const units::velocity::meters_per_second_t &velocity)
+/// Set the speed (i.e. the length of the velocity vector) of an entity. The direction of the velocity vector does not
+/// change. If the entity's speed is zero, then the function uses the orientation of the entity to derive the direction
+/// of the velocity vector.
+///
+/// @param entity   pointer to the entity
+/// @param velocity the new speed to set
+inline void SetSpeed(mantle_api::IEntity* entity, const units::velocity::meters_per_second_t& velocity)
 {
-  auto orientation = entity->GetOrientation();
-
-  auto cos_elevation = units::math::cos(orientation.pitch);
-  mantle_api::Vec3<units::velocity::meters_per_second_t> velocity_vector{velocity * units::math::cos(orientation.yaw) * cos_elevation,
-                                    velocity * units::math::sin(orientation.yaw) * cos_elevation,
-                                    velocity * -units::math::sin(orientation.pitch)};
-
-  entity->SetVelocity(velocity_vector);
+  if (entity == nullptr)
+  {
+    throw std::runtime_error("entity is null");
+  }
+
+  auto current_velocity = entity->GetVelocity();
+
+  if (current_velocity.Length()() == 0)
+  {
+    auto orientation = entity->GetOrientation();
+
+    auto cos_elevation = units::math::cos(orientation.pitch);
+    mantle_api::Vec3<units::velocity::meters_per_second_t> velocity_vector{
+        velocity * units::math::cos(orientation.yaw) * cos_elevation,
+        velocity * units::math::sin(orientation.yaw) * cos_elevation,
+        velocity * -units::math::sin(orientation.pitch)};
+
+    entity->SetVelocity(velocity_vector);
+  }
+  else
+  {
+    mantle_api::Vec3<units::velocity::meters_per_second_t> velocity_vector =
+        (current_velocity / current_velocity.Length()()) * velocity();
+    entity->SetVelocity(velocity_vector);
+  }
 }
 
 }  // namespace mantle_api