1 // Copyright Contributors to the Open Shading Language project.
2 // SPDX-License-Identifier: BSD-3-Clause
3 // https://github.com/AcademySoftwareFoundation/OpenShadingLanguage
4 
5 
6 #include <optix.h>
7 
8 #if (OPTIX_VERSION < 70000)
9 #include <optixu/optixu_aabb_namespace.h>
10 #include <optixu/optixu_math_namespace.h>
11 #include <optixu/optixu_vector_types.h>
12 
13 
14 using namespace optix;
15 
// Per-quad parameters read by intersect() and bounds() below (presumably set
// on the geometry by the host — confirm against the host-side setup).
//   p     : anchor corner; bounds() spans the corners p, p+ex, p+ey, p+ex+ey.
//   ex/ey : edge vectors of the quad.
//   n     : plane normal; used for the ray/plane test and reported as both
//           the geometric and shading normal.
//   eu/ev : scales applied to dot(h, ex) / dot(h, ey) so the hit lies in
//           [0,1) along each edge (presumably 1/|ex|^2 and 1/|ey|^2 — TODO
//           confirm).
//   a     : value reported as the surface_area attribute.
rtDeclareVariable (float3, p,  , );
rtDeclareVariable (float3, ex, , );
rtDeclareVariable (float3, ey, , );
rtDeclareVariable (float3, n,  , );
rtDeclareVariable (float,  eu, , );
rtDeclareVariable (float,  ev, , );
rtDeclareVariable (float,  a, ,  );

// Intersection attributes written by intersect() when a hit is reported, for
// consumption by the hit programs.
rtDeclareVariable (float3, texcoord,         attribute texcoord, );
rtDeclareVariable (float3, geometric_normal, attribute geometric_normal, );
rtDeclareVariable (float3, shading_normal,   attribute shading_normal, );
rtDeclareVariable (float,  surface_area,     attribute surface_area, );

rtDeclareVariable (float3, dPdu, attribute dPdu, );
rtDeclareVariable (float3, dPdv, attribute dPdv, );

// The ray currently being traced (rtCurrentRay semantic).
rtDeclareVariable (optix::Ray, ray, rtCurrentRay, );
33 
34 
intersect(void)35 RT_PROGRAM void intersect (void)
36 {
37     float dn = dot(ray.direction, n);
38     float en = dot(p - ray.origin, n);
39     if (dn * en > 0) {
40         float  t  = en / dn;
41         float3 h  = (ray.origin + ray.direction * t) - p;
42         float  dx = dot(h, ex) * eu;
43         float  dy = dot(h, ey) * ev;
44 
45         if (dx >= 0 && dx < 1.0f && dy >= 0 && dy < 1.0f && rtPotentialIntersection(t)) {
46             shading_normal = geometric_normal = n;
47             texcoord = make_float3(dot (h, ex) * eu, dot (h, ey) * ev, 0.0f);
48             dPdu = ey;
49             dPdv = ex;
50             surface_area = a;
51             rtReportIntersection(0);
52         }
53     }
54 }
55 
56 
bounds(int,float result[6])57 RT_PROGRAM void bounds (int, float result[6])
58 {
59     const float3 p00  = p;
60     const float3 p01  = p + ex;
61     const float3 p10  = p + ey;
62     const float3 p11  = p + ex + ey;
63     const float  area = length(cross(ex, ey));
64 
65     optix::Aabb* aabb = reinterpret_cast<optix::Aabb*>(result);
66 
67     if (area > 0.0f && !isinf(area)) {
68         aabb->m_min = fminf (fminf (p00, p01), fminf (p10, p11));
69         aabb->m_max = fmaxf (fmaxf (p00, p01), fmaxf (p10, p11));
70     } else {
71         aabb->invalidate();
72     }
73 }
74 
75 #else //#if (OPTIX_VERSION < 70000)
76 
77 #include "wrapper.h"
78 #include "rend_lib.h"
79 #include "render_params.h"
80 
// Fills the ShaderGlobals for a hit on quad number idx.
//
// Looks the quad up in the SBT record's data array, reconstructs the hit point
// from the ray and t_hit, and sets the normals, (u, v), dPdu/dPdv, surface
// area and shader ID.
extern "C" __device__
void __direct_callable__quad_shaderglobals (const unsigned int idx,
                                            const float        t_hit,
                                            const float3       ray_origin,
                                            const float3       ray_direction,
                                            ShaderGlobals     *sg)
{
    // The SBT data pointer leads to a GenericData whose payload is an array of
    // QuadParams indexed by idx.
    const GenericData *record = reinterpret_cast<const GenericData *>(optixGetSbtDataPointer());
    const QuadParams  *quads  = reinterpret_cast<const QuadParams *>(record->data);
    const QuadParams  &q      = quads[idx];

    // Offset of the hit point from the quad's anchor corner.
    const float3 hit_point = ray_origin + t_hit * ray_direction;
    const float3 offset    = hit_point - q.p;

    // Flat surface: shading normal equals the geometric normal.
    sg->Ng = q.n;
    sg->N  = sg->Ng;

    // Normalized coordinates along the two edges.
    sg->u = dot (offset, q.ex) * q.eu;
    sg->v = dot (offset, q.ey) * q.ev;

    // NOTE(review): u is measured along ex, so dP/du would mathematically be
    // ex rather than ey. The mapping is kept unchanged (it matches the OptiX 6
    // path); confirm against the CPU renderer before swapping it.
    sg->dPdu = q.ey;
    sg->dPdv = q.ex;

    sg->surfacearea = q.a;
    sg->shaderID    = q.shaderID;
}
103 
104 
// OptiX 7 intersection program for the quad at the current primitive index.
//
// Intersects the object-space ray with the quad's plane and reports a hit of
// kind RAYTRACER_HIT_QUAD when the hit point's normalized edge coordinates
// both lie in [0,1) and t is below the ray's tmax.
extern "C" __global__
void __intersection__quad ()
{
    const GenericData *record     = reinterpret_cast<const GenericData *>(optixGetSbtDataPointer());
    const QuadParams  *quads      = reinterpret_cast<const QuadParams *>(record->data);
    const unsigned int prim_index = optixGetPrimitiveIndex();
    const QuadParams  &q          = quads[prim_index];

    const float3 origin    = optixGetObjectRayOrigin();
    const float3 direction = optixGetObjectRayDirection();

    // Ray/plane intersection: t = dot(p - o, n) / dot(d, n).
    const float denom = dot(direction, q.n);
    const float numer = dot(q.p - origin, q.n);

    // The plane must lie strictly ahead of the ray: numerator and denominator
    // must share a sign (a zero or NaN product is rejected).
    if (!(denom * numer > 0))
        return;

    const float  t      = numer / denom;
    const float3 offset = (origin + direction * t) - q.p;
    const float  u_edge = dot(offset, q.ex) * q.eu;
    const float  v_edge = dot(offset, q.ey) * q.ev;

    const bool within_u = u_edge >= 0 && u_edge < 1.0f;
    const bool within_v = v_edge >= 0 && v_edge < 1.0f;

    if (within_u && within_v && t < optixGetRayTmax())
        optixReportIntersection (t, RAYTRACER_HIT_QUAD);
}
127 
128 #endif //#if (OPTIX_VERSION < 70000)
129