1 /*
2  * Copyright 2019, NVIDIA Corporation.
3  * Copyright 2019, Blender Foundation.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 // clang-format off
19 #include "kernel/kernel_compat_optix.h"
20 #include "util/util_atomic.h"
21 #include "kernel/kernel_types.h"
22 #include "kernel/kernel_globals.h"
23 #include "../cuda/kernel_cuda_image.h"  // Texture lookup uses normal CUDA intrinsics
24 
25 #include "kernel/kernel_path.h"
26 #include "kernel/kernel_bake.h"
27 // clang-format on
28 
get_payload_ptr_0()29 template<typename T> ccl_device_forceinline T *get_payload_ptr_0()
30 {
31   return (T *)(((uint64_t)optixGetPayload_1() << 32) | optixGetPayload_0());
32 }
get_payload_ptr_2()33 template<typename T> ccl_device_forceinline T *get_payload_ptr_2()
34 {
35   return (T *)(((uint64_t)optixGetPayload_3() << 32) | optixGetPayload_2());
36 }
37 
get_object_id()38 template<bool always = false> ccl_device_forceinline uint get_object_id()
39 {
40 #ifdef __OBJECT_MOTION__
41   // Always get the the instance ID from the TLAS
42   // There might be a motion transform node between TLAS and BLAS which does not have one
43   uint object = optixGetInstanceIdFromHandle(optixGetTransformListHandle(0));
44 #else
45   uint object = optixGetInstanceId();
46 #endif
47   // Choose between always returning object ID or only for instances
48   if (always)
49     // Can just remove the high bit since instance always contains object ID
50     return object & 0x7FFFFF;
51   // Set to OBJECT_NONE if this is not an instanced object
52   else if (object & 0x800000)
53     object = OBJECT_NONE;
54   return object;
55 }
56 
__raygen__kernel_optix_path_trace()57 extern "C" __global__ void __raygen__kernel_optix_path_trace()
58 {
59   KernelGlobals kg;  // Allocate stack storage for common data
60 
61   const uint3 launch_index = optixGetLaunchIndex();
62   // Keep threads for same pixel together to improve occupancy of warps
63   uint pixel_offset = launch_index.x / __params.tile.num_samples;
64   uint sample_offset = launch_index.x % __params.tile.num_samples;
65 
66   kernel_path_trace(&kg,
67                     __params.tile.buffer,
68                     __params.tile.start_sample + sample_offset,
69                     __params.tile.x + pixel_offset,
70                     __params.tile.y + launch_index.y,
71                     __params.tile.offset,
72                     __params.tile.stride);
73 }
74 
75 #ifdef __BAKING__
__raygen__kernel_optix_bake()76 extern "C" __global__ void __raygen__kernel_optix_bake()
77 {
78   KernelGlobals kg;
79   const ShaderParams &p = __params.shader;
80   kernel_bake_evaluate(&kg,
81                        p.input,
82                        p.output,
83                        (ShaderEvalType)p.type,
84                        p.filter,
85                        p.sx + optixGetLaunchIndex().x,
86                        p.offset,
87                        p.sample);
88 }
89 #endif
90 
__raygen__kernel_optix_displace()91 extern "C" __global__ void __raygen__kernel_optix_displace()
92 {
93   KernelGlobals kg;
94   const ShaderParams &p = __params.shader;
95   kernel_displace_evaluate(&kg, p.input, p.output, p.sx + optixGetLaunchIndex().x);
96 }
97 
__raygen__kernel_optix_background()98 extern "C" __global__ void __raygen__kernel_optix_background()
99 {
100   KernelGlobals kg;
101   const ShaderParams &p = __params.shader;
102   kernel_background_evaluate(&kg, p.input, p.output, p.sx + optixGetLaunchIndex().x);
103 }
104 
__miss__kernel_optix_miss()105 extern "C" __global__ void __miss__kernel_optix_miss()
106 {
107   // 'kernel_path_lamp_emission' checks intersection distance, so need to set it even on a miss
108   optixSetPayload_0(__float_as_uint(optixGetRayTmax()));
109   optixSetPayload_5(PRIMITIVE_NONE);
110 }
111 
__anyhit__kernel_optix_local_hit()112 extern "C" __global__ void __anyhit__kernel_optix_local_hit()
113 {
114 #ifdef __BVH_LOCAL__
115   const uint object = get_object_id<true>();
116   if (object != optixGetPayload_4() /* local_object */) {
117     // Only intersect with matching object
118     return optixIgnoreIntersection();
119   }
120 
121   int hit = 0;
122   uint *const lcg_state = get_payload_ptr_0<uint>();
123   LocalIntersection *const local_isect = get_payload_ptr_2<LocalIntersection>();
124 
125   if (lcg_state) {
126     const uint max_hits = optixGetPayload_5();
127     for (int i = min(max_hits, local_isect->num_hits) - 1; i >= 0; --i) {
128       if (optixGetRayTmax() == local_isect->hits[i].t) {
129         return optixIgnoreIntersection();
130       }
131     }
132 
133     hit = local_isect->num_hits++;
134 
135     if (local_isect->num_hits > max_hits) {
136       hit = lcg_step_uint(lcg_state) % local_isect->num_hits;
137       if (hit >= max_hits) {
138         return optixIgnoreIntersection();
139       }
140     }
141   }
142   else {
143     if (local_isect->num_hits && optixGetRayTmax() > local_isect->hits[0].t) {
144       // Record closest intersection only
145       // Do not terminate ray here, since there is no guarantee about distance ordering in any-hit
146       return optixIgnoreIntersection();
147     }
148 
149     local_isect->num_hits = 1;
150   }
151 
152   Intersection *isect = &local_isect->hits[hit];
153   isect->t = optixGetRayTmax();
154   isect->prim = optixGetPrimitiveIndex();
155   isect->object = get_object_id();
156   isect->type = kernel_tex_fetch(__prim_type, isect->prim);
157 
158   const float2 barycentrics = optixGetTriangleBarycentrics();
159   isect->u = 1.0f - barycentrics.y - barycentrics.x;
160   isect->v = barycentrics.x;
161 
162   // Record geometric normal
163   const uint tri_vindex = kernel_tex_fetch(__prim_tri_index, isect->prim);
164   const float3 tri_a = float4_to_float3(kernel_tex_fetch(__prim_tri_verts, tri_vindex + 0));
165   const float3 tri_b = float4_to_float3(kernel_tex_fetch(__prim_tri_verts, tri_vindex + 1));
166   const float3 tri_c = float4_to_float3(kernel_tex_fetch(__prim_tri_verts, tri_vindex + 2));
167   local_isect->Ng[hit] = normalize(cross(tri_b - tri_a, tri_c - tri_a));
168 
169   // Continue tracing (without this the trace call would return after the first hit)
170   optixIgnoreIntersection();
171 #endif
172 }
173 
__anyhit__kernel_optix_shadow_all_hit()174 extern "C" __global__ void __anyhit__kernel_optix_shadow_all_hit()
175 {
176 #ifdef __SHADOW_RECORD_ALL__
177   const uint prim = optixGetPrimitiveIndex();
178 #  ifdef __VISIBILITY_FLAG__
179   const uint visibility = optixGetPayload_4();
180   if ((kernel_tex_fetch(__prim_visibility, prim) & visibility) == 0) {
181     return optixIgnoreIntersection();
182   }
183 #  endif
184 
185   // Offset into array with num_hits
186   Intersection *const isect = get_payload_ptr_0<Intersection>() + optixGetPayload_2();
187   isect->t = optixGetRayTmax();
188   isect->prim = prim;
189   isect->object = get_object_id();
190   isect->type = kernel_tex_fetch(__prim_type, prim);
191 
192   if (optixIsTriangleHit()) {
193     const float2 barycentrics = optixGetTriangleBarycentrics();
194     isect->u = 1.0f - barycentrics.y - barycentrics.x;
195     isect->v = barycentrics.x;
196   }
197 #  ifdef __HAIR__
198   else {
199     const float u = __uint_as_float(optixGetAttribute_0());
200     isect->u = u;
201     isect->v = __uint_as_float(optixGetAttribute_1());
202 
203     // Filter out curve endcaps
204     if (u == 0.0f || u == 1.0f) {
205       return optixIgnoreIntersection();
206     }
207   }
208 #  endif
209 
210 #  ifdef __TRANSPARENT_SHADOWS__
211   // Detect if this surface has a shader with transparent shadows
212   if (!shader_transparent_shadow(NULL, isect) || optixGetPayload_2() >= optixGetPayload_3()) {
213 #  endif
214     // This is an opaque hit or the hit limit has been reached, abort traversal
215     optixSetPayload_5(true);
216     return optixTerminateRay();
217 #  ifdef __TRANSPARENT_SHADOWS__
218   }
219 
220   optixSetPayload_2(optixGetPayload_2() + 1);  // num_hits++
221 
222   // Continue tracing
223   optixIgnoreIntersection();
224 #  endif
225 #endif
226 }
227 
__anyhit__kernel_optix_visibility_test()228 extern "C" __global__ void __anyhit__kernel_optix_visibility_test()
229 {
230   uint visibility = optixGetPayload_4();
231 #ifdef __VISIBILITY_FLAG__
232   const uint prim = optixGetPrimitiveIndex();
233   if ((kernel_tex_fetch(__prim_visibility, prim) & visibility) == 0) {
234     return optixIgnoreIntersection();
235   }
236 #endif
237 
238 #ifdef __HAIR__
239   if (!optixIsTriangleHit()) {
240     // Filter out curve endcaps
241     const float u = __uint_as_float(optixGetAttribute_0());
242     if (u == 0.0f || u == 1.0f) {
243       return optixIgnoreIntersection();
244     }
245   }
246 #endif
247 
248   // Shadow ray early termination
249   if (visibility & PATH_RAY_SHADOW_OPAQUE) {
250     return optixTerminateRay();
251   }
252 }
253 
__closesthit__kernel_optix_hit()254 extern "C" __global__ void __closesthit__kernel_optix_hit()
255 {
256   optixSetPayload_0(__float_as_uint(optixGetRayTmax()));  // Intersection distance
257   optixSetPayload_3(optixGetPrimitiveIndex());
258   optixSetPayload_4(get_object_id());
259   // Can be PRIMITIVE_TRIANGLE and PRIMITIVE_MOTION_TRIANGLE or curve type and segment index
260   optixSetPayload_5(kernel_tex_fetch(__prim_type, optixGetPrimitiveIndex()));
261 
262   if (optixIsTriangleHit()) {
263     const float2 barycentrics = optixGetTriangleBarycentrics();
264     optixSetPayload_1(__float_as_uint(1.0f - barycentrics.y - barycentrics.x));
265     optixSetPayload_2(__float_as_uint(barycentrics.x));
266   }
267   else {
268     optixSetPayload_1(optixGetAttribute_0());  // Same as 'optixGetCurveParameter()'
269     optixSetPayload_2(optixGetAttribute_1());
270   }
271 }
272 
273 #ifdef __HAIR__
optix_intersection_curve(const uint prim,const uint type)274 ccl_device_inline void optix_intersection_curve(const uint prim, const uint type)
275 {
276   const uint object = get_object_id<true>();
277   const uint visibility = optixGetPayload_4();
278 
279   float3 P = optixGetObjectRayOrigin();
280   float3 dir = optixGetObjectRayDirection();
281 
282   // The direction is not normalized by default, but the curve intersection routine expects that
283   float len;
284   dir = normalize_len(dir, &len);
285 
286 #  ifdef __OBJECT_MOTION__
287   const float time = optixGetRayTime();
288 #  else
289   const float time = 0.0f;
290 #  endif
291 
292   Intersection isect;
293   isect.t = optixGetRayTmax();
294   // Transform maximum distance into object space
295   if (isect.t != FLT_MAX)
296     isect.t *= len;
297 
298   if (curve_intersect(NULL, &isect, P, dir, visibility, object, prim, time, type)) {
299     optixReportIntersection(isect.t / len,
300                             type & PRIMITIVE_ALL,
301                             __float_as_int(isect.u),   // Attribute_0
302                             __float_as_int(isect.v));  // Attribute_1
303   }
304 }
305 
__intersection__curve_ribbon()306 extern "C" __global__ void __intersection__curve_ribbon()
307 {
308   const uint prim = optixGetPrimitiveIndex();
309   const uint type = kernel_tex_fetch(__prim_type, prim);
310 
311   if (type & (PRIMITIVE_CURVE_RIBBON | PRIMITIVE_MOTION_CURVE_RIBBON)) {
312     optix_intersection_curve(prim, type);
313   }
314 }
315 
__intersection__curve_all()316 extern "C" __global__ void __intersection__curve_all()
317 {
318   const uint prim = optixGetPrimitiveIndex();
319   const uint type = kernel_tex_fetch(__prim_type, prim);
320   optix_intersection_curve(prim, type);
321 }
322 #endif
323 
324 #ifdef __KERNEL_DEBUG__
__exception__kernel_optix_exception()325 extern "C" __global__ void __exception__kernel_optix_exception()
326 {
327   printf("Unhandled exception occured: code %d!\n", optixGetExceptionCode());
328 }
329 #endif
330