diff --git a/MantleAPI/include/MantleAPI/Traffic/entity_helper.h b/MantleAPI/include/MantleAPI/Traffic/entity_helper.h
index a09a4534b122807c65fb6dda7f586fb981a05d1b..a6d85b6bb2acb8ba11b994807e141a8c22284290 100644
--- a/MantleAPI/include/MantleAPI/Traffic/entity_helper.h
+++ b/MantleAPI/include/MantleAPI/Traffic/entity_helper.h
@@ -15,23 +15,48 @@
 #ifndef MANTLEAPI_TRAFFIC_ENTITY_HELPER_H
 #define MANTLEAPI_TRAFFIC_ENTITY_HELPER_H
 
+#include <MantleAPI/Common/floating_point_helper.h>
 #include <MantleAPI/Traffic/i_entity.h>
 
 #include <cmath>
 
 namespace mantle_api
 {
-
-inline void SetSpeed(mantle_api::IEntity* entity, const units::velocity::meters_per_second_t &velocity)
+/// Set the speed (i.e. the length of the velocity vector) of an entity. The direction of the velocity vector does not
+/// change. If the entity's speed is zero, then the function uses the orientation of the entity to derive the direction
+/// of the velocity vector.
+///
+/// @param entity   pointer to the entity
+/// @param velocity the new speed to set
+inline void SetSpeed(mantle_api::IEntity* entity, const units::velocity::meters_per_second_t& velocity)
 {
-  auto orientation = entity->GetOrientation();
+  using namespace units::literals;
+
+  if (entity == nullptr)
+  {
+    throw std::runtime_error("entity is null");
+  }
+
+  auto current_velocity = entity->GetVelocity();
+
+  if (IsEqual(current_velocity.Length(), 0_mps))
+  {
+    auto orientation = entity->GetOrientation();
 
-  auto cos_elevation = units::math::cos(orientation.pitch);
-  mantle_api::Vec3<units::velocity::meters_per_second_t> velocity_vector{velocity * units::math::cos(orientation.yaw) * cos_elevation,
-                                    velocity * units::math::sin(orientation.yaw) * cos_elevation,
-                                    velocity * -units::math::sin(orientation.pitch)};
+    auto cos_elevation = units::math::cos(orientation.pitch);
+    mantle_api::Vec3<units::velocity::meters_per_second_t> velocity_vector{
+        velocity * units::math::cos(orientation.yaw) * cos_elevation,
+        velocity * units::math::sin(orientation.yaw) * cos_elevation,
+        velocity * -units::math::sin(orientation.pitch)};
 
-  entity->SetVelocity(velocity_vector);
+    entity->SetVelocity(velocity_vector);
+  }
+  else
+  {
+    mantle_api::Vec3<units::velocity::meters_per_second_t> velocity_vector =
+        (current_velocity / current_velocity.Length()()) * velocity();
+    entity->SetVelocity(velocity_vector);
+  }
 }
 
 }  // namespace mantle_api