1 // =============================================================================
2 // PROJECT CHRONO - http://projectchrono.org
3 //
4 // Copyright (c) 2019 projectchrono.org
5 // All rights reserved.
6 //
7 // Use of this source code is governed by a BSD-style license that can be found
8 // in the LICENSE file at the top level of the distribution and at
9 // http://projectchrono.org/license-chrono.txt.
10 //
11 // =============================================================================
12 // Authors: Asher Elmquist
13 // =============================================================================
14 //
// RT kernels for cylinder geometries
16 //
17 // =============================================================================
18 #ifdef _WIN32
19 #ifndef NOMINMAX
20 #define NOMINMAX
21 #endif
22 #endif
23 
24 #include "chrono_sensor/optix/shaders/device_utils.h"
25 
// Tests an object-space ray against the two end caps of the canonical cylinder
// (unit radius, axis along y, caps in the planes y = +0.5 and y = -0.5) and
// reports each cap hit whose point lies strictly inside the unit disk and whose
// distance lies strictly inside (ray_tmin, ray_tmax).
//
// Each report uses hit kind 0 and 8 attributes (float bits reinterpreted as
// unsigned int):
//   [0..2] object-space shading normal
//   [3..4] texture coordinate
//   [5..7] object-space tangent vector
//
// ray_orig, ray_dir : object-space ray
// ray_tmin, ray_tmax: open interval the hit distance must fall within
__device__ __inline__ void check_ends(const float3& ray_orig,
                                      const float3& ray_dir,
                                      const float& ray_tmin,
                                      const float& ray_tmax) {
    // cap == 0 -> top cap (side = +1, plane y = +0.5)
    // cap == 1 -> bottom cap (side = -1, plane y = -0.5)
    // The top cap is tested first, matching the original report order.
#pragma unroll
    for (int cap = 0; cap < 2; ++cap) {
        const float side = (cap == 0) ? 1.f : -1.f;

        // Distance to the cap plane. A ray parallel to the plane (ray_dir.y == 0)
        // produces +-inf or NaN here, which the t-range test below rejects.
        const float t = (0.5f * side - ray_orig.y) / ray_dir.y;
        const float3 p = ray_orig + ray_dir * t;

        if (p.x * p.x + p.z * p.z - 1.f < 0.f && t > ray_tmin && t < ray_tmax) {
            // Outward-facing normal: +y on the top cap, -y on the bottom cap.
            // (The bottom cap previously reported +y, i.e. an inward normal.)
            float3 shading_normal = make_float3(0.f, side, 0.f);
            // Tangent and texture coordinates are mirrored between the two caps.
            float3 tangent_vector = make_float3(-side, 0.f, 0.f);
            float2 texcoord = make_float2(-side * p.x / 2.f + .5f, -side * p.z / 2.f + .5f);
            optixReportIntersection(
                t,                                                  //
                0,                                                  //
                reinterpret_cast<unsigned int&>(shading_normal.x),  //
                reinterpret_cast<unsigned int&>(shading_normal.y),  //
                reinterpret_cast<unsigned int&>(shading_normal.z),  //
                reinterpret_cast<unsigned int&>(texcoord.x), reinterpret_cast<unsigned int&>(texcoord.y),
                reinterpret_cast<unsigned int&>(tangent_vector.x), reinterpret_cast<unsigned int&>(tangent_vector.y),
                reinterpret_cast<unsigned int&>(tangent_vector.z));
        }
    }
}
65 
// Reports a hit on the curved side of the canonical cylinder at distance t.
// p = ray_orig + t * ray_dir is expected to lie on the surface x^2 + z^2 = 1.
// Uses hit kind 0 and the same 8-attribute layout as the cap hits:
//   [0..2] shading normal, [3..4] texcoord, [5..7] tangent vector.
static __device__ __inline__ void report_cylinder_side_hit(const float t, const float3& p) {
    // Radial direction (x, 0, z); approximately unit length since the radius is 1.
    float3 shading_normal = make_float3(p.x, 0.f, p.z);
    // Direction around the axis, perpendicular to the normal in the xz-plane.
    float3 tangent_vector = make_float3(p.z, 0.f, -p.x);
    // u: angle around the y axis divided by 2*pi, in [-0.5, 0.5].
    // v: height y in (-0.5, 0.5) mapped onto (0, 1).
    //    Previously y * 0.5 + 0.5, which used only the (0.25, 0.75) band of the texture.
    float2 texcoord = make_float2(atan2f(p.x, p.z) / (2.f * CUDART_PI_F), p.y + 0.5f);
    optixReportIntersection(t,                                                  //
                            0,                                                  //
                            reinterpret_cast<unsigned int&>(shading_normal.x),  //
                            reinterpret_cast<unsigned int&>(shading_normal.y),  //
                            reinterpret_cast<unsigned int&>(shading_normal.z),  //
                            reinterpret_cast<unsigned int&>(texcoord.x),
                            reinterpret_cast<unsigned int&>(texcoord.y),
                            reinterpret_cast<unsigned int&>(tangent_vector.x),
                            reinterpret_cast<unsigned int&>(tangent_vector.y),
                            reinterpret_cast<unsigned int&>(tangent_vector.z));
}

// OptiX intersection program for the canonical cylinder: unit radius, axis along y,
// spanning y in [-0.5, 0.5], in object space.
// Cap hits are tested and reported first (check_ends). Then the infinite cylinder
// x^2 + z^2 = 1 is intersected, and the nearest side hit within (tmin, tmax) and
// between the caps is reported.
extern "C" __global__ void __intersection__cylinder_intersect() {
    const float3 ray_orig = optixGetObjectRayOrigin();
    const float3 ray_dir = optixGetObjectRayDirection();
    const float ray_tmin = optixGetRayTmin();
    const float ray_tmax = optixGetRayTmax();

    check_ends(ray_orig, ray_dir, ray_tmin, ray_tmax);

    // Solve a*t^2 + b*t + c = 0 for (o + t*d) on x^2 + z^2 = 1.
    const float a = ray_dir.x * ray_dir.x + ray_dir.z * ray_dir.z;
    const float b = 2.f * (ray_dir.x * ray_orig.x + ray_dir.z * ray_orig.z);
    const float c = ray_orig.x * ray_orig.x + ray_orig.z * ray_orig.z - 1.f;
    const float det = b * b - 4.f * a * c;

    // Tangent (det == 0) and axis-parallel (a == 0) rays produce no side hit.
    // With a > 0, dist_near <= dist_far is guaranteed.
    if (det > 0.f && a > 0.f) {
        const float sqrt_det = sqrtf(det);
        const float inv_2a = 1.f / (2.f * a);
        const float dist_near = (-b - sqrt_det) * inv_2a;
        const float dist_far = (-b + sqrt_det) * inv_2a;

        const float3 p_near = ray_orig + ray_dir * dist_near;
        if (dist_near > ray_tmin && dist_near < ray_tmax && p_near.y < 0.5f && p_near.y > -0.5f) {
            report_cylinder_side_hit(dist_near, p_near);
        } else {
            // Near root rejected (behind the origin, beyond tmax, or above/below the
            // caps): the exit point may still be a valid side hit.
            const float3 p_far = ray_orig + ray_dir * dist_far;
            if (dist_far > ray_tmin && dist_far < ray_tmax && p_far.y < 0.5f && p_far.y > -0.5f) {
                report_cylinder_side_hit(dist_far, p_far);
            }
        }
    }
}
119